Refactor and optimize the flat type calculations (#4708)

* Optimize flat type representation calculations

Previously calculating the flat type representation would be done
recursively for an entire type tree every time it was visited.
Additionally the flat type representation was entirely built only to be
thrown away if it was too large at the end. This chiefly presented a
source of recursion based on the type structure in the component model
which fuzzing does not like as it reports stack overflows.

This commit overhauls the representation of flat types in Wasmtime by
caching the representation for each type in the compile-time
`ComponentTypesBuilder` structure. This avoids recalculating each time
the flat representation is queried and additionally allows opportunity
to have more short-circuiting to avoid building overly-large vectors.

* Remove duplicate flat count calculation in wasmtime

Roughly share the infrastructure in the `wasmtime-environ` crate, namely
the non-recursive and memoizing nature of the calculation.

* Fix component fuzz build

* Fix example compile
This commit is contained in:
Alex Crichton
2022-08-16 13:31:47 -05:00
committed by GitHub
parent 3c1490dd59
commit bc8e36a6af
14 changed files with 561 additions and 279 deletions

View File

@@ -1015,7 +1015,7 @@ fn expand_flags(flags: &Flags) -> Result<TokenStream> {
}); });
} }
FlagsSize::Size4Plus(n) => { FlagsSize::Size4Plus(n) => {
count = n; count = usize::from(n);
as_array = TokenStream::new(); as_array = TokenStream::new();
bitor = TokenStream::new(); bitor = TokenStream::new();
bitor_assign = TokenStream::new(); bitor_assign = TokenStream::new();
@@ -1072,7 +1072,7 @@ fn expand_flags(flags: &Flags) -> Result<TokenStream> {
.map(|i| { .map(|i| {
let field = format_ident!("__inner{}", i); let field = format_ident!("__inner{}", i);
let init = if index / 32 == i { let init = if index / 32 == usize::from(i) {
1_u32 << (index % 32) 1_u32 << (index % 32)
} else { } else {
0 0

View File

@@ -60,7 +60,7 @@ pub enum FlagsSize {
/// Flags can fit in a u16 /// Flags can fit in a u16
Size2, Size2,
/// Flags can fit in a specified number of u32 fields /// Flags can fit in a specified number of u32 fields
Size4Plus(usize), Size4Plus(u8),
} }
impl FlagsSize { impl FlagsSize {
@@ -73,7 +73,11 @@ impl FlagsSize {
} else if count <= 16 { } else if count <= 16 {
FlagsSize::Size2 FlagsSize::Size2
} else { } else {
FlagsSize::Size4Plus(ceiling_divide(count, 32)) let amt = ceiling_divide(count, 32);
if amt > (u8::MAX as usize) {
panic!("too many flags");
}
FlagsSize::Size4Plus(amt as u8)
} }
} }
} }

View File

@@ -174,7 +174,6 @@ impl Factc {
} }
types.pop_type_scope(); types.pop_type_scope();
let types = types.finish();
let mut fact_module = Module::new(&types, self.debug); let mut fact_module = Module::new(&types, self.debug);
for (i, adapter) in adapters.iter().enumerate() { for (i, adapter) in adapters.iter().enumerate() {
fact_module.adapt(&format!("adapter{i}"), adapter); fact_module.adapt(&format!("adapter{i}"), adapter);

View File

@@ -143,7 +143,6 @@ fn target(module: GenAdapterModule) {
types.pop_type_scope(); types.pop_type_scope();
} }
let types = types.finish();
let mut fact_module = Module::new(&types, module.debug); let mut fact_module = Module::new(&types, module.debug);
for (i, adapter) in adapters.iter().enumerate() { for (i, adapter) in adapters.iter().enumerate() {
fact_module.adapt(&format!("adapter{i}"), adapter); fact_module.adapt(&format!("adapter{i}"), adapter);

View File

@@ -184,10 +184,7 @@ impl<'data> Translator<'_, 'data> {
// the module using standard core wasm translation, and then fills out // the module using standard core wasm translation, and then fills out
// the dfg metadata for each adapter. // the dfg metadata for each adapter.
for (module_id, adapter_module) in state.adapter_modules.iter() { for (module_id, adapter_module) in state.adapter_modules.iter() {
let mut module = fact::Module::new( let mut module = fact::Module::new(self.types, self.tunables.debug_adapter_modules);
self.types.component_types(),
self.tunables.debug_adapter_modules,
);
let mut names = Vec::with_capacity(adapter_module.adapters.len()); let mut names = Vec::with_capacity(adapter_module.adapters.len());
for adapter in adapter_module.adapters.iter() { for adapter in adapter_module.adapters.iter() {
let name = format!("adapter{}", adapter.as_u32()); let name = format!("adapter{}", adapter.as_u32());

View File

@@ -1,3 +1,4 @@
use crate::component::{MAX_FLAT_PARAMS, MAX_FLAT_RESULTS};
use crate::{ use crate::{
EntityType, Global, GlobalInit, ModuleTypes, ModuleTypesBuilder, PrimaryMap, SignatureIndex, EntityType, Global, GlobalInit, ModuleTypes, ModuleTypesBuilder, PrimaryMap, SignatureIndex,
}; };
@@ -318,6 +319,11 @@ pub struct ComponentTypesBuilder {
component_types: ComponentTypes, component_types: ComponentTypes,
module_types: ModuleTypesBuilder, module_types: ModuleTypesBuilder,
// Cache of what the "flat" representation of all types are which is only
// used at compile-time and not used at runtime, hence the location here
// as opposed to `ComponentTypes`.
flat: FlatTypesCache,
} }
#[derive(Default)] #[derive(Default)]
@@ -326,6 +332,21 @@ struct TypeScope {
component: PrimaryMap<ComponentTypeIndex, TypeDef>, component: PrimaryMap<ComponentTypeIndex, TypeDef>,
} }
macro_rules! intern_and_fill_flat_types {
($me:ident, $name:ident, $val:ident) => {{
if let Some(idx) = $me.$name.get(&$val) {
return *idx;
}
let idx = $me.component_types.$name.push($val.clone());
let mut storage = FlatTypesStorage::new();
storage.$name($me, &$val);
let idx2 = $me.flat.$name.push(storage);
assert_eq!(idx, idx2);
$me.$name.insert($val, idx);
return idx;
}};
}
impl ComponentTypesBuilder { impl ComponentTypesBuilder {
/// Finishes this list of component types and returns the finished /// Finishes this list of component types and returns the finished
/// structure. /// structure.
@@ -769,42 +790,42 @@ impl ComponentTypesBuilder {
/// Interns a new record type within this type information. /// Interns a new record type within this type information.
pub fn add_record_type(&mut self, ty: TypeRecord) -> TypeRecordIndex { pub fn add_record_type(&mut self, ty: TypeRecord) -> TypeRecordIndex {
intern(&mut self.records, &mut self.component_types.records, ty) intern_and_fill_flat_types!(self, records, ty)
} }
/// Interns a new flags type within this type information. /// Interns a new flags type within this type information.
pub fn add_flags_type(&mut self, ty: TypeFlags) -> TypeFlagsIndex { pub fn add_flags_type(&mut self, ty: TypeFlags) -> TypeFlagsIndex {
intern(&mut self.flags, &mut self.component_types.flags, ty) intern_and_fill_flat_types!(self, flags, ty)
} }
/// Interns a new tuple type within this type information. /// Interns a new tuple type within this type information.
pub fn add_tuple_type(&mut self, ty: TypeTuple) -> TypeTupleIndex { pub fn add_tuple_type(&mut self, ty: TypeTuple) -> TypeTupleIndex {
intern(&mut self.tuples, &mut self.component_types.tuples, ty) intern_and_fill_flat_types!(self, tuples, ty)
} }
/// Interns a new variant type within this type information. /// Interns a new variant type within this type information.
pub fn add_variant_type(&mut self, ty: TypeVariant) -> TypeVariantIndex { pub fn add_variant_type(&mut self, ty: TypeVariant) -> TypeVariantIndex {
intern(&mut self.variants, &mut self.component_types.variants, ty) intern_and_fill_flat_types!(self, variants, ty)
} }
/// Interns a new union type within this type information. /// Interns a new union type within this type information.
pub fn add_union_type(&mut self, ty: TypeUnion) -> TypeUnionIndex { pub fn add_union_type(&mut self, ty: TypeUnion) -> TypeUnionIndex {
intern(&mut self.unions, &mut self.component_types.unions, ty) intern_and_fill_flat_types!(self, unions, ty)
} }
/// Interns a new enum type within this type information. /// Interns a new enum type within this type information.
pub fn add_enum_type(&mut self, ty: TypeEnum) -> TypeEnumIndex { pub fn add_enum_type(&mut self, ty: TypeEnum) -> TypeEnumIndex {
intern(&mut self.enums, &mut self.component_types.enums, ty) intern_and_fill_flat_types!(self, enums, ty)
} }
/// Interns a new option type within this type information. /// Interns a new option type within this type information.
pub fn add_option_type(&mut self, ty: TypeOption) -> TypeOptionIndex { pub fn add_option_type(&mut self, ty: TypeOption) -> TypeOptionIndex {
intern(&mut self.options, &mut self.component_types.options, ty) intern_and_fill_flat_types!(self, options, ty)
} }
/// Interns a new expected type within this type information. /// Interns a new expected type within this type information.
pub fn add_expected_type(&mut self, ty: TypeExpected) -> TypeExpectedIndex { pub fn add_expected_type(&mut self, ty: TypeExpected) -> TypeExpectedIndex {
intern(&mut self.expecteds, &mut self.component_types.expecteds, ty) intern_and_fill_flat_types!(self, expecteds, ty)
} }
/// Interns a new expected type within this type information. /// Interns a new expected type within this type information.
@@ -815,6 +836,43 @@ impl ComponentTypesBuilder {
ty, ty,
) )
} }
/// Returns the canonical ABI information about the specified type.
pub fn canonical_abi(&self, ty: &InterfaceType) -> &CanonicalAbiInfo {
self.component_types.canonical_abi(ty)
}
/// Returns the "flat types" for the given interface type used in the
/// canonical ABI.
///
/// Returns `None` if the type is too large to be represented via flat types
/// in the canonical abi.
pub fn flat_types(&self, ty: &InterfaceType) -> Option<FlatTypes<'_>> {
match ty {
InterfaceType::Unit => Some(FlatTypes::EMPTY),
InterfaceType::U8
| InterfaceType::S8
| InterfaceType::Bool
| InterfaceType::U16
| InterfaceType::S16
| InterfaceType::U32
| InterfaceType::S32
| InterfaceType::Char => Some(FlatTypes::I32),
InterfaceType::U64 | InterfaceType::S64 => Some(FlatTypes::I64),
InterfaceType::Float32 => Some(FlatTypes::F32),
InterfaceType::Float64 => Some(FlatTypes::F64),
InterfaceType::String | InterfaceType::List(_) => Some(FlatTypes::POINTER_PAIR),
InterfaceType::Record(i) => self.flat.records[*i].as_flat_types(),
InterfaceType::Variant(i) => self.flat.variants[*i].as_flat_types(),
InterfaceType::Tuple(i) => self.flat.tuples[*i].as_flat_types(),
InterfaceType::Flags(i) => self.flat.flags[*i].as_flat_types(),
InterfaceType::Enum(i) => self.flat.enums[*i].as_flat_types(),
InterfaceType::Union(i) => self.flat.unions[*i].as_flat_types(),
InterfaceType::Option(i) => self.flat.options[*i].as_flat_types(),
InterfaceType::Expected(i) => self.flat.expecteds[*i].as_flat_types(),
}
}
} }
// Forward the indexing impl to the internal `TypeTables` // Forward the indexing impl to the internal `TypeTables`
@@ -986,6 +1044,13 @@ pub struct CanonicalAbiInfo {
pub size64: u32, pub size64: u32,
/// The byte-alignment of this type in a 64-bit memory. /// The byte-alignment of this type in a 64-bit memory.
pub align64: u32, pub align64: u32,
/// The number of types it takes to represents this type in the "flat"
/// representation of the canonical abi where everything is passed as
/// immediate arguments or results.
///
/// If this is `None` then this type is not representable in the flat ABI
/// because it is too large.
pub flat_count: Option<u8>,
} }
impl Default for CanonicalAbiInfo { impl Default for CanonicalAbiInfo {
@@ -995,6 +1060,7 @@ impl Default for CanonicalAbiInfo {
align32: 1, align32: 1,
size64: 0, size64: 0,
align64: 1, align64: 1,
flat_count: Some(0),
} }
} }
} }
@@ -1019,6 +1085,7 @@ impl CanonicalAbiInfo {
align32: 1, align32: 1,
size64: 0, size64: 0,
align64: 1, align64: 1,
flat_count: Some(0),
}; };
/// ABI information for one-byte scalars. /// ABI information for one-byte scalars.
@@ -1036,6 +1103,7 @@ impl CanonicalAbiInfo {
align32: size, align32: size,
size64: size, size64: size,
align64: size, align64: size,
flat_count: Some(1),
} }
} }
@@ -1045,6 +1113,7 @@ impl CanonicalAbiInfo {
align32: 4, align32: 4,
size64: 16, size64: 16,
align64: 8, align64: 8,
flat_count: Some(2),
}; };
/// Returns the abi for a record represented by the specified fields. /// Returns the abi for a record represented by the specified fields.
@@ -1058,6 +1127,7 @@ impl CanonicalAbiInfo {
ret.align32 = ret.align32.max(field.align32); ret.align32 = ret.align32.max(field.align32);
ret.size64 = align_to(ret.size64, field.align64) + field.size64; ret.size64 = align_to(ret.size64, field.align64) + field.size64;
ret.align64 = ret.align64.max(field.align64); ret.align64 = ret.align64.max(field.align64);
ret.flat_count = add_flat(ret.flat_count, field.flat_count);
} }
ret.size32 = align_to(ret.size32, ret.align32); ret.size32 = align_to(ret.size32, ret.align32);
ret.size64 = align_to(ret.size64, ret.align64); ret.size64 = align_to(ret.size64, ret.align64);
@@ -1077,6 +1147,7 @@ impl CanonicalAbiInfo {
ret.align32 = max(ret.align32, field.align32); ret.align32 = max(ret.align32, field.align32);
ret.size64 = align_to(ret.size64, field.align64) + field.size64; ret.size64 = align_to(ret.size64, field.align64) + field.size64;
ret.align64 = max(ret.align64, field.align64); ret.align64 = max(ret.align64, field.align64);
ret.flat_count = add_flat(ret.flat_count, field.flat_count);
i += 1; i += 1;
} }
ret.size32 = align_to(ret.size32, ret.align32); ret.size32 = align_to(ret.size32, ret.align32);
@@ -1116,17 +1187,18 @@ impl CanonicalAbiInfo {
/// Returns ABI information for a structure which contains `count` flags. /// Returns ABI information for a structure which contains `count` flags.
pub const fn flags(count: usize) -> CanonicalAbiInfo { pub const fn flags(count: usize) -> CanonicalAbiInfo {
let (size, align) = match FlagsSize::from_count(count) { let (size, align, flat_count) = match FlagsSize::from_count(count) {
FlagsSize::Size0 => (0, 1), FlagsSize::Size0 => (0, 1, 0),
FlagsSize::Size1 => (1, 1), FlagsSize::Size1 => (1, 1, 1),
FlagsSize::Size2 => (2, 2), FlagsSize::Size2 => (2, 2, 1),
FlagsSize::Size4Plus(n) => ((n as u32) * 4, 4), FlagsSize::Size4Plus(n) => ((n as u32) * 4, 4, n),
}; };
CanonicalAbiInfo { CanonicalAbiInfo {
size32: size, size32: size,
align32: align, align32: align,
size64: size, size64: size,
align64: align, align64: align,
flat_count: Some(flat_count),
} }
} }
@@ -1144,11 +1216,13 @@ impl CanonicalAbiInfo {
let mut max_align32 = discrim_size; let mut max_align32 = discrim_size;
let mut max_size64 = 0; let mut max_size64 = 0;
let mut max_align64 = discrim_size; let mut max_align64 = discrim_size;
let mut max_case_count = Some(0);
for case in cases { for case in cases {
max_size32 = max_size32.max(case.size32); max_size32 = max_size32.max(case.size32);
max_align32 = max_align32.max(case.align32); max_align32 = max_align32.max(case.align32);
max_size64 = max_size64.max(case.size64); max_size64 = max_size64.max(case.size64);
max_align64 = max_align64.max(case.align64); max_align64 = max_align64.max(case.align64);
max_case_count = max_flat(max_case_count, case.flat_count);
} }
CanonicalAbiInfo { CanonicalAbiInfo {
size32: align_to( size32: align_to(
@@ -1161,6 +1235,7 @@ impl CanonicalAbiInfo {
max_align64, max_align64,
), ),
align64: max_align64, align64: max_align64,
flat_count: add_flat(max_case_count, Some(1)),
} }
} }
@@ -1177,6 +1252,7 @@ impl CanonicalAbiInfo {
let mut max_align32 = discrim_size; let mut max_align32 = discrim_size;
let mut max_size64 = 0; let mut max_size64 = 0;
let mut max_align64 = discrim_size; let mut max_align64 = discrim_size;
let mut max_case_count = Some(0);
let mut i = 0; let mut i = 0;
while i < cases.len() { while i < cases.len() {
let case = &cases[i]; let case = &cases[i];
@@ -1184,6 +1260,7 @@ impl CanonicalAbiInfo {
max_align32 = max(max_align32, case.align32); max_align32 = max(max_align32, case.align32);
max_size64 = max(max_size64, case.size64); max_size64 = max(max_size64, case.size64);
max_align64 = max(max_align64, case.align64); max_align64 = max(max_align64, case.align64);
max_case_count = max_flat(max_case_count, case.flat_count);
i += 1; i += 1;
} }
CanonicalAbiInfo { CanonicalAbiInfo {
@@ -1197,6 +1274,18 @@ impl CanonicalAbiInfo {
max_align64, max_align64,
), ),
align64: max_align64, align64: max_align64,
flat_count: add_flat(max_case_count, Some(1)),
}
}
/// Returns the flat count of this ABI information so long as the count
/// doesn't exceed the `max` specified.
pub fn flat_count(&self, max: usize) -> Option<usize> {
let flat = usize::from(self.flat_count?);
if flat > max {
None
} else {
Some(flat)
} }
} }
} }
@@ -1396,3 +1485,284 @@ pub struct TypeExpected {
/// Byte information about this variant type. /// Byte information about this variant type.
pub info: VariantInfo, pub info: VariantInfo,
} }
const MAX_FLAT_TYPES: usize = if MAX_FLAT_PARAMS > MAX_FLAT_RESULTS {
MAX_FLAT_PARAMS
} else {
MAX_FLAT_RESULTS
};
const fn add_flat(a: Option<u8>, b: Option<u8>) -> Option<u8> {
const MAX: u8 = MAX_FLAT_TYPES as u8;
let sum = match (a, b) {
(Some(a), Some(b)) => match a.checked_add(b) {
Some(c) => c,
None => return None,
},
_ => return None,
};
if sum > MAX {
None
} else {
Some(sum)
}
}
const fn max_flat(a: Option<u8>, b: Option<u8>) -> Option<u8> {
match (a, b) {
(Some(a), Some(b)) => {
if a > b {
Some(a)
} else {
Some(b)
}
}
_ => None,
}
}
/// Flat representation of a type in just core wasm types.
pub struct FlatTypes<'a> {
/// The flat representation of this type in 32-bit memories.
pub memory32: &'a [FlatType],
/// The flat representation of this type in 64-bit memories.
pub memory64: &'a [FlatType],
}
#[allow(missing_docs)]
impl FlatTypes<'_> {
pub const EMPTY: FlatTypes<'static> = FlatTypes::new(&[]);
pub const I32: FlatTypes<'static> = FlatTypes::new(&[FlatType::I32]);
pub const I64: FlatTypes<'static> = FlatTypes::new(&[FlatType::I64]);
pub const F32: FlatTypes<'static> = FlatTypes::new(&[FlatType::F32]);
pub const F64: FlatTypes<'static> = FlatTypes::new(&[FlatType::F64]);
pub const POINTER_PAIR: FlatTypes<'static> = FlatTypes {
memory32: &[FlatType::I32, FlatType::I32],
memory64: &[FlatType::I64, FlatType::I64],
};
const fn new(flat: &[FlatType]) -> FlatTypes<'_> {
FlatTypes {
memory32: flat,
memory64: flat,
}
}
/// Returns the number of flat types used to represent this type.
///
/// Note that this length is the same regardless to the size of memory.
pub fn len(&self) -> usize {
assert_eq!(self.memory32.len(), self.memory64.len());
self.memory32.len()
}
}
// Note that this is intentionally duplicated here to keep the size to 1 byte
// irregardless to changes in the core wasm type system since this will only
// ever use integers/floats for the forseeable future.
#[derive(PartialEq, Eq, Copy, Clone)]
#[allow(missing_docs)]
pub enum FlatType {
I32,
I64,
F32,
F64,
}
#[derive(Default)]
struct FlatTypesCache {
records: PrimaryMap<TypeRecordIndex, FlatTypesStorage>,
variants: PrimaryMap<TypeVariantIndex, FlatTypesStorage>,
tuples: PrimaryMap<TypeTupleIndex, FlatTypesStorage>,
enums: PrimaryMap<TypeEnumIndex, FlatTypesStorage>,
flags: PrimaryMap<TypeFlagsIndex, FlatTypesStorage>,
unions: PrimaryMap<TypeUnionIndex, FlatTypesStorage>,
options: PrimaryMap<TypeOptionIndex, FlatTypesStorage>,
expecteds: PrimaryMap<TypeExpectedIndex, FlatTypesStorage>,
}
struct FlatTypesStorage {
// This could be represented as `Vec<FlatType>` but on 64-bit architectures
// that's 24 bytes. Otherwise `FlatType` is 1 byte large and
// `MAX_FLAT_TYPES` is 16, so it should ideally be more space-efficient to
// use a flat array instead of a heap-based vector.
memory32: [FlatType; MAX_FLAT_TYPES],
memory64: [FlatType; MAX_FLAT_TYPES],
// Tracks the number of flat types pushed into this storage. If this is
// `MAX_FLAT_TYPES + 1` then this storage represents an un-reprsentable
// type in flat types.
len: u8,
}
impl FlatTypesStorage {
fn new() -> FlatTypesStorage {
FlatTypesStorage {
memory32: [FlatType::I32; MAX_FLAT_TYPES],
memory64: [FlatType::I32; MAX_FLAT_TYPES],
len: 0,
}
}
fn as_flat_types(&self) -> Option<FlatTypes<'_>> {
let len = usize::from(self.len);
if len > MAX_FLAT_TYPES {
assert_eq!(len, MAX_FLAT_TYPES + 1);
None
} else {
Some(FlatTypes {
memory32: &self.memory32[..len],
memory64: &self.memory64[..len],
})
}
}
/// Pushes a new flat type into this list using `t32` for 32-bit memories
/// and `t64` for 64-bit memories.
///
/// Returns whether the type was actually pushed or whether this list of
/// flat types just exceeded the maximum meaning that it is now
/// unrepresentable with a flat list of types.
fn push(&mut self, t32: FlatType, t64: FlatType) -> bool {
let len = usize::from(self.len);
if len < MAX_FLAT_TYPES {
self.memory32[len] = t32;
self.memory64[len] = t64;
self.len += 1;
true
} else {
// If this was the first one to go over then flag the length as
// being incompatible with a flat representation.
if len == MAX_FLAT_TYPES {
self.len += 1;
}
false
}
}
/// Builds up all flat types internally using the specified representation
/// for all of the component fields of the record.
fn build_record<'a>(&mut self, types: impl Iterator<Item = Option<FlatTypes<'a>>>) {
for ty in types {
let types = match ty {
Some(types) => types,
None => {
self.len = u8::try_from(MAX_FLAT_TYPES + 1).unwrap();
return;
}
};
for (t32, t64) in types.memory32.iter().zip(types.memory64) {
if !self.push(*t32, *t64) {
return;
}
}
}
}
/// Builds up the flat types used to represent a `variant` which notably
/// handles "join"ing types together so each case is representable as a
/// single flat list of types.
fn build_variant<'a, I>(&mut self, cases: I)
where
I: IntoIterator<Item = Option<FlatTypes<'a>>>,
{
let cases = cases.into_iter();
self.push(FlatType::I32, FlatType::I32);
for ty in cases {
let types = match ty {
Some(types) => types,
// If this case isn't representable with a flat list of types
// then this variant also isn't representable.
None => {
self.len = u8::try_from(MAX_FLAT_TYPES + 1).unwrap();
return;
}
};
// If the case used all of the flat types then the discriminant
// added for this variant means that this variant is no longer
// representable.
if types.memory32.len() >= MAX_FLAT_TYPES {
self.len = u8::try_from(MAX_FLAT_TYPES + 1).unwrap();
return;
}
let dst = self.memory32.iter_mut().zip(&mut self.memory64).skip(1);
for (i, ((t32, t64), (dst32, dst64))) in types
.memory32
.iter()
.zip(types.memory64)
.zip(dst)
.enumerate()
{
if i + 1 < usize::from(self.len) {
// If this index hs already been set by some previous case
// then the types are joined together.
dst32.join(*t32);
dst64.join(*t64);
} else {
// Otherwise if this is the first time that the
// representation has gotten this large then the destination
// is simply whatever the type is. The length is also
// increased here to indicate this.
self.len += 1;
*dst32 = *t32;
*dst64 = *t64;
}
}
}
}
fn records(&mut self, types: &ComponentTypesBuilder, ty: &TypeRecord) {
self.build_record(ty.fields.iter().map(|f| types.flat_types(&f.ty)));
}
fn tuples(&mut self, types: &ComponentTypesBuilder, ty: &TypeTuple) {
self.build_record(ty.types.iter().map(|t| types.flat_types(t)));
}
fn enums(&mut self, _types: &ComponentTypesBuilder, _ty: &TypeEnum) {
self.push(FlatType::I32, FlatType::I32);
}
fn flags(&mut self, _types: &ComponentTypesBuilder, ty: &TypeFlags) {
match FlagsSize::from_count(ty.names.len()) {
FlagsSize::Size0 => {}
FlagsSize::Size1 | FlagsSize::Size2 => {
self.push(FlatType::I32, FlatType::I32);
}
FlagsSize::Size4Plus(n) => {
for _ in 0..n {
self.push(FlatType::I32, FlatType::I32);
}
}
}
}
fn variants(&mut self, types: &ComponentTypesBuilder, ty: &TypeVariant) {
self.build_variant(ty.cases.iter().map(|c| types.flat_types(&c.ty)))
}
fn unions(&mut self, types: &ComponentTypesBuilder, ty: &TypeUnion) {
self.build_variant(ty.types.iter().map(|t| types.flat_types(t)))
}
fn expecteds(&mut self, types: &ComponentTypesBuilder, ty: &TypeExpected) {
self.build_variant([types.flat_types(&ty.ok), types.flat_types(&ty.err)]);
}
fn options(&mut self, types: &ComponentTypesBuilder, ty: &TypeOption) {
self.build_variant([Some(FlatTypes::EMPTY), types.flat_types(&ty.ty)]);
}
}
impl FlatType {
fn join(&mut self, other: FlatType) {
if *self == other {
return;
}
*self = match (*self, other) {
(FlatType::I32, FlatType::F32) | (FlatType::F32, FlatType::I32) => FlatType::I32,
_ => FlatType::I64,
};
}
}

View File

@@ -20,8 +20,8 @@
use crate::component::dfg::CoreDef; use crate::component::dfg::CoreDef;
use crate::component::{ use crate::component::{
Adapter, AdapterOptions as AdapterOptionsDfg, ComponentTypes, InterfaceType, StringEncoding, Adapter, AdapterOptions as AdapterOptionsDfg, ComponentTypesBuilder, InterfaceType,
TypeFuncIndex, StringEncoding, TypeFuncIndex,
}; };
use crate::fact::transcode::Transcoder; use crate::fact::transcode::Transcoder;
use crate::{EntityRef, FuncIndex, GlobalIndex, MemoryIndex, PrimaryMap}; use crate::{EntityRef, FuncIndex, GlobalIndex, MemoryIndex, PrimaryMap};
@@ -41,7 +41,7 @@ pub struct Module<'a> {
/// Whether or not debug code is inserted into the adapters themselves. /// Whether or not debug code is inserted into the adapters themselves.
debug: bool, debug: bool,
/// Type information from the creator of this `Module` /// Type information from the creator of this `Module`
types: &'a ComponentTypes, types: &'a ComponentTypesBuilder,
/// Core wasm type section that's incrementally built /// Core wasm type section that's incrementally built
core_types: core_types::CoreTypes, core_types: core_types::CoreTypes,
@@ -125,7 +125,7 @@ enum Context {
impl<'a> Module<'a> { impl<'a> Module<'a> {
/// Creates an empty module. /// Creates an empty module.
pub fn new(types: &'a ComponentTypes, debug: bool) -> Module<'a> { pub fn new(types: &'a ComponentTypesBuilder, debug: bool) -> Module<'a> {
Module { Module {
debug, debug,
types, types,

View File

@@ -1,9 +1,10 @@
//! Size, align, and flattening information about component model types. //! Size, align, and flattening information about component model types.
use crate::component::{ComponentTypes, InterfaceType, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS}; use crate::component::{
ComponentTypesBuilder, FlatType, InterfaceType, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS,
};
use crate::fact::{AdapterOptions, Context, Options}; use crate::fact::{AdapterOptions, Context, Options};
use wasm_encoder::ValType; use wasm_encoder::ValType;
use wasmtime_component_util::FlagsSize;
/// Metadata about a core wasm signature which is created for a component model /// Metadata about a core wasm signature which is created for a component model
/// signature. /// signature.
@@ -23,7 +24,7 @@ pub struct Signature {
pub results_indirect: bool, pub results_indirect: bool,
} }
impl ComponentTypes { impl ComponentTypesBuilder {
/// Calculates the core wasm function signature for the component function /// Calculates the core wasm function signature for the component function
/// type specified within `Context`. /// type specified within `Context`.
/// ///
@@ -34,31 +35,39 @@ impl ComponentTypes {
let ty = &self[options.ty]; let ty = &self[options.ty];
let ptr_ty = options.options.ptr(); let ptr_ty = options.options.ptr();
let mut params = self.flatten_types(&options.options, ty.params.iter().map(|(_, ty)| *ty));
let mut params_indirect = false; let mut params_indirect = false;
if params.len() > MAX_FLAT_PARAMS { let mut params = match self.flatten_types(
params = vec![ptr_ty]; &options.options,
MAX_FLAT_PARAMS,
ty.params.iter().map(|(_, ty)| *ty),
) {
Some(list) => list,
None => {
params_indirect = true; params_indirect = true;
vec![ptr_ty]
} }
};
let mut results = self.flatten_types(&options.options, [ty.result]);
let mut results_indirect = false; let mut results_indirect = false;
if results.len() > MAX_FLAT_RESULTS { let results = match self.flatten_types(&options.options, MAX_FLAT_RESULTS, [ty.result]) {
Some(list) => list,
None => {
results_indirect = true; results_indirect = true;
match context { match context {
// For a lifted function too-many-results gets translated to a // For a lifted function too-many-results gets translated to a
// returned pointer where results are read from. The callee // returned pointer where results are read from. The callee
// allocates space here. // allocates space here.
Context::Lift => results = vec![ptr_ty], Context::Lift => vec![ptr_ty],
// For a lowered function too-many-results becomes a return // For a lowered function too-many-results becomes a return
// pointer which is passed as the last argument. The caller // pointer which is passed as the last argument. The caller
// allocates space here. // allocates space here.
Context::Lower => { Context::Lower => {
results.truncate(0);
params.push(ptr_ty); params.push(ptr_ty);
Vec::new()
} }
} }
} }
};
Signature { Signature {
params, params,
results, results,
@@ -72,115 +81,31 @@ impl ComponentTypes {
pub(super) fn flatten_types( pub(super) fn flatten_types(
&self, &self,
opts: &Options, opts: &Options,
max: usize,
tys: impl IntoIterator<Item = InterfaceType>, tys: impl IntoIterator<Item = InterfaceType>,
) -> Vec<ValType> { ) -> Option<Vec<ValType>> {
let mut result = Vec::new(); let mut dst = Vec::new();
for ty in tys { for ty in tys {
self.push_flat(opts, &ty, &mut result); let flat = self.flat_types(&ty)?;
let types = if opts.memory64 {
flat.memory64
} else {
flat.memory32
};
for ty in types {
let ty = match ty {
FlatType::I32 => ValType::I32,
FlatType::I64 => ValType::I64,
FlatType::F32 => ValType::F32,
FlatType::F64 => ValType::F64,
};
if dst.len() == max {
return None;
} }
result dst.push(ty);
}
fn push_flat(&self, opts: &Options, ty: &InterfaceType, dst: &mut Vec<ValType>) {
match ty {
InterfaceType::Unit => {}
InterfaceType::Bool
| InterfaceType::S8
| InterfaceType::U8
| InterfaceType::S16
| InterfaceType::U16
| InterfaceType::S32
| InterfaceType::U32
| InterfaceType::Char => dst.push(ValType::I32),
InterfaceType::S64 | InterfaceType::U64 => dst.push(ValType::I64),
InterfaceType::Float32 => dst.push(ValType::F32),
InterfaceType::Float64 => dst.push(ValType::F64),
InterfaceType::String | InterfaceType::List(_) => {
dst.push(opts.ptr());
dst.push(opts.ptr());
}
InterfaceType::Record(r) => {
for field in self[*r].fields.iter() {
self.push_flat(opts, &field.ty, dst);
}
}
InterfaceType::Tuple(t) => {
for ty in self[*t].types.iter() {
self.push_flat(opts, ty, dst);
}
}
InterfaceType::Flags(f) => {
let flags = &self[*f];
match FlagsSize::from_count(flags.names.len()) {
FlagsSize::Size0 => {}
FlagsSize::Size1 | FlagsSize::Size2 => dst.push(ValType::I32),
FlagsSize::Size4Plus(n) => {
dst.extend((0..n).map(|_| ValType::I32));
}
}
}
InterfaceType::Enum(_) => dst.push(ValType::I32),
InterfaceType::Option(t) => {
dst.push(ValType::I32);
self.push_flat(opts, &self[*t].ty, dst);
}
InterfaceType::Variant(t) => {
dst.push(ValType::I32);
let pos = dst.len();
let mut tmp = Vec::new();
for case in self[*t].cases.iter() {
self.push_flat_variant(opts, &case.ty, pos, &mut tmp, dst);
}
}
InterfaceType::Union(t) => {
dst.push(ValType::I32);
let pos = dst.len();
let mut tmp = Vec::new();
for ty in self[*t].types.iter() {
self.push_flat_variant(opts, ty, pos, &mut tmp, dst);
}
}
InterfaceType::Expected(t) => {
dst.push(ValType::I32);
let e = &self[*t];
let pos = dst.len();
let mut tmp = Vec::new();
self.push_flat_variant(opts, &e.ok, pos, &mut tmp, dst);
self.push_flat_variant(opts, &e.err, pos, &mut tmp, dst);
}
}
}
fn push_flat_variant(
&self,
opts: &Options,
ty: &InterfaceType,
pos: usize,
tmp: &mut Vec<ValType>,
dst: &mut Vec<ValType>,
) {
tmp.truncate(0);
self.push_flat(opts, ty, tmp);
for (i, a) in tmp.iter().enumerate() {
match dst.get_mut(pos + i) {
Some(b) => join(*a, b),
None => dst.push(*a),
}
}
fn join(a: ValType, b: &mut ValType) {
if a == *b {
return;
}
match (a, *b) {
(ValType::I32, ValType::F32) | (ValType::F32, ValType::I32) => *b = ValType::I32,
_ => *b = ValType::I64,
} }
} }
Some(dst)
} }
pub(super) fn align(&self, opts: &Options, ty: &InterfaceType) -> u32 { pub(super) fn align(&self, opts: &Options, ty: &InterfaceType) -> u32 {

View File

@@ -16,7 +16,7 @@
//! can be somewhat arbitrary, an intentional decision. //! can be somewhat arbitrary, an intentional decision.
use crate::component::{ use crate::component::{
CanonicalAbiInfo, ComponentTypes, InterfaceType, StringEncoding, TypeEnumIndex, CanonicalAbiInfo, ComponentTypesBuilder, InterfaceType, StringEncoding, TypeEnumIndex,
TypeExpectedIndex, TypeFlagsIndex, TypeInterfaceIndex, TypeOptionIndex, TypeRecordIndex, TypeExpectedIndex, TypeFlagsIndex, TypeInterfaceIndex, TypeOptionIndex, TypeRecordIndex,
TypeTupleIndex, TypeUnionIndex, TypeVariantIndex, VariantInfo, FLAG_MAY_ENTER, FLAG_MAY_LEAVE, TypeTupleIndex, TypeUnionIndex, TypeVariantIndex, VariantInfo, FLAG_MAY_ENTER, FLAG_MAY_LEAVE,
MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS,
@@ -36,7 +36,7 @@ const MAX_STRING_BYTE_LENGTH: u32 = 1 << 31;
const UTF16_TAG: u32 = 1 << 31; const UTF16_TAG: u32 = 1 << 31;
struct Compiler<'a, 'b> { struct Compiler<'a, 'b> {
types: &'a ComponentTypes, types: &'a ComponentTypesBuilder,
module: &'b mut Module<'a>, module: &'b mut Module<'a>,
result: FunctionId, result: FunctionId,
@@ -279,14 +279,16 @@ impl Compiler<'_, '_> {
// TODO: handle subtyping // TODO: handle subtyping
assert_eq!(src_tys.len(), dst_tys.len()); assert_eq!(src_tys.len(), dst_tys.len());
let src_flat = self let src_flat =
.types self.types
.flatten_types(lower_opts, src_tys.iter().copied()); .flatten_types(lower_opts, MAX_FLAT_PARAMS, src_tys.iter().copied());
let dst_flat = self.types.flatten_types(lift_opts, dst_tys.iter().copied()); let dst_flat =
self.types
.flatten_types(lift_opts, MAX_FLAT_PARAMS, dst_tys.iter().copied());
let src = if src_flat.len() <= MAX_FLAT_PARAMS { let src = if let Some(flat) = &src_flat {
Source::Stack(Stack { Source::Stack(Stack {
locals: &param_locals[..src_flat.len()], locals: &param_locals[..flat.len()],
opts: lower_opts, opts: lower_opts,
}) })
} else { } else {
@@ -303,8 +305,8 @@ impl Compiler<'_, '_> {
Source::Memory(self.memory_operand(lower_opts, TempLocal::new(addr, ty), align)) Source::Memory(self.memory_operand(lower_opts, TempLocal::new(addr, ty), align))
}; };
let dst = if dst_flat.len() <= MAX_FLAT_PARAMS { let dst = if let Some(flat) = &dst_flat {
Destination::Stack(&dst_flat, lift_opts) Destination::Stack(flat, lift_opts)
} else { } else {
// If there are too many parameters then space is allocated in the // If there are too many parameters then space is allocated in the
// destination module for the parameters via its `realloc` function. // destination module for the parameters via its `realloc` function.
@@ -348,10 +350,14 @@ impl Compiler<'_, '_> {
let lift_opts = &adapter.lift.options; let lift_opts = &adapter.lift.options;
let lower_opts = &adapter.lower.options; let lower_opts = &adapter.lower.options;
let src_flat = self.types.flatten_types(lift_opts, [src_ty]); let src_flat = self
let dst_flat = self.types.flatten_types(lower_opts, [dst_ty]); .types
.flatten_types(lift_opts, MAX_FLAT_RESULTS, [src_ty]);
let dst_flat = self
.types
.flatten_types(lower_opts, MAX_FLAT_RESULTS, [dst_ty]);
let src = if src_flat.len() <= MAX_FLAT_RESULTS { let src = if src_flat.is_some() {
Source::Stack(Stack { Source::Stack(Stack {
locals: result_locals, locals: result_locals,
opts: lift_opts, opts: lift_opts,
@@ -368,8 +374,8 @@ impl Compiler<'_, '_> {
Source::Memory(self.memory_operand(lift_opts, TempLocal::new(addr, ty), align)) Source::Memory(self.memory_operand(lift_opts, TempLocal::new(addr, ty), align))
}; };
let dst = if dst_flat.len() <= MAX_FLAT_RESULTS { let dst = if let Some(flat) = &dst_flat {
Destination::Stack(&dst_flat, lower_opts) Destination::Stack(flat, lower_opts)
} else { } else {
// This is slightly different than `translate_params` where the // This is slightly different than `translate_params` where the
// return pointer was provided by the caller of this function // return pointer was provided by the caller of this function
@@ -1937,6 +1943,7 @@ impl Compiler<'_, '_> {
FlagsSize::Size4Plus(n) => { FlagsSize::Size4Plus(n) => {
let srcs = src.record_field_srcs(self.types, (0..n).map(|_| InterfaceType::U32)); let srcs = src.record_field_srcs(self.types, (0..n).map(|_| InterfaceType::U32));
let dsts = dst.record_field_dsts(self.types, (0..n).map(|_| InterfaceType::U32)); let dsts = dst.record_field_dsts(self.types, (0..n).map(|_| InterfaceType::U32));
let n = usize::from(n);
for (i, (src, dst)) in srcs.zip(dsts).enumerate() { for (i, (src, dst)) in srcs.zip(dsts).enumerate() {
let mask = if i == n - 1 && (cnt % 32 != 0) { let mask = if i == n - 1 && (cnt % 32 != 0) {
(1 << (cnt % 32)) - 1 (1 << (cnt % 32)) - 1
@@ -2775,7 +2782,7 @@ impl<'a> Source<'a> {
/// offset for each memory-based type. /// offset for each memory-based type.
fn record_field_srcs<'b>( fn record_field_srcs<'b>(
&'b self, &'b self,
types: &'b ComponentTypes, types: &'b ComponentTypesBuilder,
fields: impl IntoIterator<Item = InterfaceType> + 'b, fields: impl IntoIterator<Item = InterfaceType> + 'b,
) -> impl Iterator<Item = Source<'a>> + 'b ) -> impl Iterator<Item = Source<'a>> + 'b
where where
@@ -2788,7 +2795,7 @@ impl<'a> Source<'a> {
Source::Memory(mem) Source::Memory(mem)
} }
Source::Stack(stack) => { Source::Stack(stack) => {
let cnt = types.flatten_types(stack.opts, [ty]).len() as u32; let cnt = types.flat_types(&ty).unwrap().len() as u32;
offset += cnt; offset += cnt;
Source::Stack(stack.slice((offset - cnt) as usize..offset as usize)) Source::Stack(stack.slice((offset - cnt) as usize..offset as usize))
} }
@@ -2798,13 +2805,13 @@ impl<'a> Source<'a> {
/// Returns the corresponding discriminant source and payload source f /// Returns the corresponding discriminant source and payload source f
fn payload_src( fn payload_src(
&self, &self,
types: &ComponentTypes, types: &ComponentTypesBuilder,
info: &VariantInfo, info: &VariantInfo,
case: &InterfaceType, case: &InterfaceType,
) -> Source<'a> { ) -> Source<'a> {
match self { match self {
Source::Stack(s) => { Source::Stack(s) => {
let flat_len = types.flatten_types(s.opts, [*case]).len(); let flat_len = types.flat_types(case).unwrap().len();
Source::Stack(s.slice(1..s.locals.len()).slice(0..flat_len)) Source::Stack(s.slice(1..s.locals.len()).slice(0..flat_len))
} }
Source::Memory(mem) => { Source::Memory(mem) => {
@@ -2830,7 +2837,7 @@ impl<'a> Destination<'a> {
/// Same as `Source::record_field_srcs` but for destinations. /// Same as `Source::record_field_srcs` but for destinations.
fn record_field_dsts<'b>( fn record_field_dsts<'b>(
&'b self, &'b self,
types: &'b ComponentTypes, types: &'b ComponentTypesBuilder,
fields: impl IntoIterator<Item = InterfaceType> + 'b, fields: impl IntoIterator<Item = InterfaceType> + 'b,
) -> impl Iterator<Item = Destination> + 'b ) -> impl Iterator<Item = Destination> + 'b
where where
@@ -2843,7 +2850,7 @@ impl<'a> Destination<'a> {
Destination::Memory(mem) Destination::Memory(mem)
} }
Destination::Stack(s, opts) => { Destination::Stack(s, opts) => {
let cnt = types.flatten_types(opts, [ty]).len() as u32; let cnt = types.flat_types(&ty).unwrap().len() as u32;
offset += cnt; offset += cnt;
Destination::Stack(&s[(offset - cnt) as usize..offset as usize], opts) Destination::Stack(&s[(offset - cnt) as usize..offset as usize], opts)
} }
@@ -2853,13 +2860,13 @@ impl<'a> Destination<'a> {
/// Returns the corresponding discriminant source and payload source f /// Returns the corresponding discriminant source and payload source f
fn payload_dst( fn payload_dst(
&self, &self,
types: &ComponentTypes, types: &ComponentTypesBuilder,
info: &VariantInfo, info: &VariantInfo,
case: &InterfaceType, case: &InterfaceType,
) -> Destination { ) -> Destination {
match self { match self {
Destination::Stack(s, opts) => { Destination::Stack(s, opts) => {
let flat_len = types.flatten_types(opts, [*case]).len(); let flat_len = types.flat_types(case).unwrap().len();
Destination::Stack(&s[1..][..flat_len], opts) Destination::Stack(&s[1..][..flat_len], opts)
} }
Destination::Memory(mem) => { Destination::Memory(mem) => {
@@ -2883,7 +2890,7 @@ impl<'a> Destination<'a> {
fn next_field_offset<'a>( fn next_field_offset<'a>(
offset: &mut u32, offset: &mut u32,
types: &ComponentTypes, types: &ComponentTypesBuilder,
field: &InterfaceType, field: &InterfaceType,
mem: &Memory<'a>, mem: &Memory<'a>,
) -> Memory<'a> { ) -> Memory<'a> {
@@ -2930,7 +2937,7 @@ struct VariantCase<'a> {
dst_ty: &'a InterfaceType, dst_ty: &'a InterfaceType,
} }
fn variant_info<I>(types: &ComponentTypes, cases: I) -> VariantInfo fn variant_info<I>(types: &ComponentTypesBuilder, cases: I) -> VariantInfo
where where
I: IntoIterator<Item = InterfaceType>, I: IntoIterator<Item = InterfaceType>,
I::IntoIter: ExactSizeIterator, I::IntoIter: ExactSizeIterator,

View File

@@ -166,7 +166,7 @@ fn u32_count_from_flag_count(count: usize) -> usize {
match FlagsSize::from_count(count) { match FlagsSize::from_count(count) {
FlagsSize::Size0 => 0, FlagsSize::Size0 => 0,
FlagsSize::Size1 | FlagsSize::Size2 => 1, FlagsSize::Size1 | FlagsSize::Size2 => 1,
FlagsSize::Size4Plus(n) => n, FlagsSize::Size4Plus(n) => n.into(),
} }
} }
@@ -270,7 +270,7 @@ impl Type {
alignment: 2, alignment: 2,
}, },
FlagsSize::Size4Plus(n) => SizeAndAlignment { FlagsSize::Size4Plus(n) => SizeAndAlignment {
size: n * 4, size: usize::from(n) * 4,
alignment: 4, alignment: 4,
}, },
}, },

View File

@@ -302,15 +302,15 @@ impl Func {
result = Type::from(&ty.result, &data.types); result = Type::from(&ty.result, &data.types);
} }
let param_count = params.iter().map(|ty| ty.flatten_count()).sum::<usize>(); let param_abi = CanonicalAbiInfo::record(params.iter().map(|t| t.canonical_abi()));
let result_count = result.flatten_count(); let result_count = result.canonical_abi().flat_count(MAX_FLAT_RESULTS);
self.call_raw( self.call_raw(
store, store,
args, args,
|store, options, args, dst: &mut MaybeUninit<[ValRaw; MAX_FLAT_PARAMS]>| { |store, options, args, dst: &mut MaybeUninit<[ValRaw; MAX_FLAT_PARAMS]>| {
if param_count > MAX_FLAT_PARAMS { if param_abi.flat_count(MAX_FLAT_PARAMS).is_none() {
self.store_args(store, &options, &params, args, dst) self.store_args(store, &options, &param_abi, &params, args, dst)
} else { } else {
dst.write([ValRaw::u64(0); MAX_FLAT_PARAMS]); dst.write([ValRaw::u64(0); MAX_FLAT_PARAMS]);
@@ -324,7 +324,7 @@ impl Func {
} }
}, },
|store, options, src: &[ValRaw; MAX_FLAT_RESULTS]| { |store, options, src: &[ValRaw; MAX_FLAT_RESULTS]| {
if result_count > MAX_FLAT_RESULTS { if result_count.is_none() {
Self::load_result(&Memory::new(store, &options), &result, &mut src.iter()) Self::load_result(&Memory::new(store, &options), &result, &mut src.iter())
} else { } else {
Val::lift(&result, store, &options, &mut src.iter()) Val::lift(&result, store, &options, &mut src.iter())
@@ -554,12 +554,11 @@ impl Func {
&self, &self,
store: &mut StoreContextMut<'_, T>, store: &mut StoreContextMut<'_, T>,
options: &Options, options: &Options,
abi: &CanonicalAbiInfo,
params: &[Type], params: &[Type],
args: &[Val], args: &[Val],
dst: &mut MaybeUninit<[ValRaw; MAX_FLAT_PARAMS]>, dst: &mut MaybeUninit<[ValRaw; MAX_FLAT_PARAMS]>,
) -> Result<()> { ) -> Result<()> {
let abi = CanonicalAbiInfo::record(params.iter().map(|t| t.canonical_abi()));
let mut memory = MemoryMut::new(store.as_context_mut(), options); let mut memory = MemoryMut::new(store.as_context_mut(), options);
let size = usize::try_from(abi.size32).unwrap(); let size = usize::try_from(abi.size32).unwrap();
let ptr = memory.realloc(0, 0, abi.align32, size)?; let ptr = memory.realloc(0, 0, abi.align32, size)?;

View File

@@ -400,12 +400,19 @@ where
bail!("cannot leave component instance"); bail!("cannot leave component instance");
} }
let param_count = params.iter().map(|ty| ty.flatten_count()).sum::<usize>();
let args; let args;
let ret_index; let ret_index;
if param_count <= MAX_FLAT_PARAMS { let param_abi = CanonicalAbiInfo::record(params.iter().map(|t| t.canonical_abi()));
let param_count = param_abi.flat_count.and_then(|i| {
let i = usize::from(i);
if i > MAX_FLAT_PARAMS {
None
} else {
Some(i)
}
});
if let Some(param_count) = param_count {
let iter = &mut storage.iter(); let iter = &mut storage.iter();
args = params args = params
.iter() .iter()
@@ -413,8 +420,6 @@ where
.collect::<Result<Box<[_]>>>()?; .collect::<Result<Box<[_]>>>()?;
ret_index = param_count; ret_index = param_count;
} else { } else {
let param_abi = CanonicalAbiInfo::record(params.iter().map(|t| t.canonical_abi()));
let memory = Memory::new(cx.0, &options); let memory = Memory::new(cx.0, &options);
let mut offset = validate_inbounds_dynamic(&param_abi, memory.as_slice(), &storage[0])?; let mut offset = validate_inbounds_dynamic(&param_abi, memory.as_slice(), &storage[0])?;
args = params args = params
@@ -436,8 +441,8 @@ where
flags.set_may_leave(false); flags.set_may_leave(false);
result.check(&ret)?; result.check(&ret)?;
let result_count = result.flatten_count(); let result_count = result.canonical_abi().flat_count(MAX_FLAT_RESULTS);
if result_count <= MAX_FLAT_RESULTS { if result_count.is_some() {
let dst = mem::transmute::<&mut [ValRaw], &mut [MaybeUninit<ValRaw>]>(storage); let dst = mem::transmute::<&mut [ValRaw], &mut [MaybeUninit<ValRaw>]>(storage);
ret.lower(&mut cx, &options, &mut dst.iter_mut())?; ret.lower(&mut cx, &options, &mut dst.iter_mut())?;
} else { } else {

View File

@@ -78,6 +78,10 @@ impl Record {
ty: Type::from(&field.ty, &self.0.types), ty: Type::from(&field.ty, &self.0.types),
}) })
} }
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
&self.0.types[self.0.index].abi
}
} }
/// A `tuple` interface type /// A `tuple` interface type
@@ -97,6 +101,10 @@ impl Tuple {
.iter() .iter()
.map(|ty| Type::from(ty, &self.0.types)) .map(|ty| Type::from(ty, &self.0.types))
} }
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
&self.0.types[self.0.index].abi
}
} }
/// A case declaration belonging to a `variant` /// A case declaration belonging to a `variant`
@@ -128,6 +136,10 @@ impl Variant {
pub(crate) fn variant_info(&self) -> &VariantInfo { pub(crate) fn variant_info(&self) -> &VariantInfo {
&self.0.types[self.0.index].info &self.0.types[self.0.index].info
} }
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
&self.0.types[self.0.index].abi
}
} }
/// An `enum` interface type /// An `enum` interface type
@@ -151,6 +163,10 @@ impl Enum {
pub(crate) fn variant_info(&self) -> &VariantInfo { pub(crate) fn variant_info(&self) -> &VariantInfo {
&self.0.types[self.0.index].info &self.0.types[self.0.index].info
} }
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
&self.0.types[self.0.index].abi
}
} }
/// A `union` interface type /// A `union` interface type
@@ -174,6 +190,10 @@ impl Union {
pub(crate) fn variant_info(&self) -> &VariantInfo { pub(crate) fn variant_info(&self) -> &VariantInfo {
&self.0.types[self.0.index].info &self.0.types[self.0.index].info
} }
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
&self.0.types[self.0.index].abi
}
} }
/// An `option` interface type /// An `option` interface type
@@ -194,6 +214,10 @@ impl Option {
pub(crate) fn variant_info(&self) -> &VariantInfo { pub(crate) fn variant_info(&self) -> &VariantInfo {
&self.0.types[self.0.index].info &self.0.types[self.0.index].info
} }
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
&self.0.types[self.0.index].abi
}
} }
/// An `expected` interface type /// An `expected` interface type
@@ -219,6 +243,10 @@ impl Expected {
pub(crate) fn variant_info(&self) -> &VariantInfo { pub(crate) fn variant_info(&self) -> &VariantInfo {
&self.0.types[self.0.index].info &self.0.types[self.0.index].info
} }
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
&self.0.types[self.0.index].abi
}
} }
/// A `flags` interface type /// A `flags` interface type
@@ -238,6 +266,10 @@ impl Flags {
.iter() .iter()
.map(|name| name.deref()) .map(|name| name.deref())
} }
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
&self.0.types[self.0.index].abi
}
} }
/// Represents a component model interface type /// Represents a component model interface type
@@ -483,60 +515,6 @@ impl Type {
} }
} }
/// Return the number of stack slots needed to store values of this type in lowered form.
pub(crate) fn flatten_count(&self) -> usize {
match self {
Type::Unit => 0,
Type::Bool
| Type::S8
| Type::U8
| Type::S16
| Type::U16
| Type::S32
| Type::U32
| Type::S64
| Type::U64
| Type::Float32
| Type::Float64
| Type::Char
| Type::Enum(_) => 1,
Type::String | Type::List(_) => 2,
Type::Record(handle) => handle.fields().map(|field| field.ty.flatten_count()).sum(),
Type::Tuple(handle) => handle.types().map(|ty| ty.flatten_count()).sum(),
Type::Variant(handle) => {
1 + handle
.cases()
.map(|case| case.ty.flatten_count())
.max()
.unwrap_or(0)
}
Type::Union(handle) => {
1 + handle
.types()
.map(|ty| ty.flatten_count())
.max()
.unwrap_or(0)
}
Type::Option(handle) => 1 + handle.ty().flatten_count(),
Type::Expected(handle) => {
1 + handle
.ok()
.flatten_count()
.max(handle.err().flatten_count())
}
Type::Flags(handle) => values::u32_count_for_flag_count(handle.names().len()),
}
}
fn desc(&self) -> &'static str { fn desc(&self) -> &'static str {
match self { match self {
Type::Unit => "unit", Type::Unit => "unit",
@@ -574,14 +552,14 @@ impl Type {
Type::S32 | Type::U32 | Type::Char | Type::Float32 => &CanonicalAbiInfo::SCALAR4, Type::S32 | Type::U32 | Type::Char | Type::Float32 => &CanonicalAbiInfo::SCALAR4,
Type::S64 | Type::U64 | Type::Float64 => &CanonicalAbiInfo::SCALAR8, Type::S64 | Type::U64 | Type::Float64 => &CanonicalAbiInfo::SCALAR8,
Type::String | Type::List(_) => &CanonicalAbiInfo::POINTER_PAIR, Type::String | Type::List(_) => &CanonicalAbiInfo::POINTER_PAIR,
Type::Record(handle) => &handle.0.types[handle.0.index].abi, Type::Record(handle) => handle.canonical_abi(),
Type::Tuple(handle) => &handle.0.types[handle.0.index].abi, Type::Tuple(handle) => handle.canonical_abi(),
Type::Variant(handle) => &handle.0.types[handle.0.index].abi, Type::Variant(handle) => handle.canonical_abi(),
Type::Enum(handle) => &handle.0.types[handle.0.index].abi, Type::Enum(handle) => handle.canonical_abi(),
Type::Union(handle) => &handle.0.types[handle.0.index].abi, Type::Union(handle) => handle.canonical_abi(),
Type::Option(handle) => &handle.0.types[handle.0.index].abi, Type::Option(handle) => handle.canonical_abi(),
Type::Expected(handle) => &handle.0.types[handle.0.index].abi, Type::Expected(handle) => handle.canonical_abi(),
Type::Flags(handle) => &handle.0.types[handle.0.index].abi, Type::Flags(handle) => handle.canonical_abi(),
} }
} }
} }

View File

@@ -440,7 +440,8 @@ impl Flags {
.map(|(index, name)| (name, index)) .map(|(index, name)| (name, index))
.collect::<HashMap<_, _>>(); .collect::<HashMap<_, _>>();
let mut values = vec![0_u32; u32_count_for_flag_count(ty.names().len())]; let count = usize::from(ty.canonical_abi().flat_count.unwrap());
let mut values = vec![0_u32; count];
for name in names { for name in names {
let index = map let index = map
@@ -611,7 +612,7 @@ impl Val {
}), }),
Type::Variant(handle) => { Type::Variant(handle) => {
let (discriminant, value) = lift_variant( let (discriminant, value) = lift_variant(
ty.flatten_count(), handle.canonical_abi().flat_count(usize::MAX).unwrap(),
handle.cases().map(|case| case.ty), handle.cases().map(|case| case.ty),
store, store,
options, options,
@@ -626,7 +627,7 @@ impl Val {
} }
Type::Enum(handle) => { Type::Enum(handle) => {
let (discriminant, _) = lift_variant( let (discriminant, _) = lift_variant(
ty.flatten_count(), handle.canonical_abi().flat_count(usize::MAX).unwrap(),
handle.names().map(|_| Type::Unit), handle.names().map(|_| Type::Unit),
store, store,
options, options,
@@ -639,8 +640,13 @@ impl Val {
}) })
} }
Type::Union(handle) => { Type::Union(handle) => {
let (discriminant, value) = let (discriminant, value) = lift_variant(
lift_variant(ty.flatten_count(), handle.types(), store, options, src)?; handle.canonical_abi().flat_count(usize::MAX).unwrap(),
handle.types(),
store,
options,
src,
)?;
Val::Union(Union { Val::Union(Union {
ty: handle.clone(), ty: handle.clone(),
@@ -650,7 +656,7 @@ impl Val {
} }
Type::Option(handle) => { Type::Option(handle) => {
let (discriminant, value) = lift_variant( let (discriminant, value) = lift_variant(
ty.flatten_count(), handle.canonical_abi().flat_count(usize::MAX).unwrap(),
[Type::Unit, handle.ty()].into_iter(), [Type::Unit, handle.ty()].into_iter(),
store, store,
options, options,
@@ -665,7 +671,7 @@ impl Val {
} }
Type::Expected(handle) => { Type::Expected(handle) => {
let (discriminant, value) = lift_variant( let (discriminant, value) = lift_variant(
ty.flatten_count(), handle.canonical_abi().flat_count(usize::MAX).unwrap(),
[handle.ok(), handle.err()].into_iter(), [handle.ok(), handle.err()].into_iter(),
store, store,
options, options,
@@ -680,8 +686,9 @@ impl Val {
} }
Type::Flags(handle) => { Type::Flags(handle) => {
let count = u32::try_from(handle.names().len()).unwrap(); let count = u32::try_from(handle.names().len()).unwrap();
let u32_count = handle.canonical_abi().flat_count(usize::MAX).unwrap();
let value = iter::repeat_with(|| u32::lift(store, options, next(src))) let value = iter::repeat_with(|| u32::lift(store, options, next(src)))
.take(u32_count_for_flag_count(count.try_into()?)) .take(u32_count)
.collect::<Result<_>>()?; .collect::<Result<_>>()?;
Val::Flags(Flags { Val::Flags(Flags {
@@ -797,7 +804,7 @@ impl Val {
FlagsSize::Size1 => iter::once(u8::load(mem, bytes)? as u32).collect(), FlagsSize::Size1 => iter::once(u8::load(mem, bytes)? as u32).collect(),
FlagsSize::Size2 => iter::once(u16::load(mem, bytes)? as u32).collect(), FlagsSize::Size2 => iter::once(u16::load(mem, bytes)? as u32).collect(),
FlagsSize::Size4Plus(n) => (0..n) FlagsSize::Size4Plus(n) => (0..n)
.map(|index| u32::load(mem, &bytes[index * 4..][..4])) .map(|index| u32::load(mem, &bytes[usize::from(index) * 4..][..4]))
.collect::<Result<_>>()?, .collect::<Result<_>>()?,
}, },
}), }),
@@ -868,7 +875,9 @@ impl Val {
}) => { }) => {
next_mut(dst).write(ValRaw::u32(*discriminant)); next_mut(dst).write(ValRaw::u32(*discriminant));
value.lower(store, options, dst)?; value.lower(store, options, dst)?;
for _ in (1 + value.ty().flatten_count())..self.ty().flatten_count() { let value_flat = value.ty().canonical_abi().flat_count(usize::MAX).unwrap();
let variant_flat = self.ty().canonical_abi().flat_count(usize::MAX).unwrap();
for _ in (1 + value_flat)..variant_flat {
next_mut(dst).write(ValRaw::u32(0)); next_mut(dst).write(ValRaw::u32(0));
} }
} }
@@ -1070,7 +1079,8 @@ fn lift_variant<'a>(
.nth(discriminant as usize) .nth(discriminant as usize)
.ok_or_else(|| anyhow!("discriminant {} out of range [0..{})", discriminant, len))?; .ok_or_else(|| anyhow!("discriminant {} out of range [0..{})", discriminant, len))?;
let value = Val::lift(&ty, store, options, src)?; let value = Val::lift(&ty, store, options, src)?;
for _ in (1 + ty.flatten_count())..flatten_count { let value_flat = ty.canonical_abi().flat_count(usize::MAX).unwrap();
for _ in (1 + value_flat)..flatten_count {
next(src); next(src);
} }
Ok((discriminant, value)) Ok((discriminant, value))
@@ -1098,17 +1108,6 @@ fn lower_list<T>(
Ok((ptr, items.len())) Ok((ptr, items.len()))
} }
/// Calculate the size of a u32 array needed to represent the specified number of bit flags.
///
/// Note that this will always return at least 1, even if the `count` parameter is zero.
pub(crate) fn u32_count_for_flag_count(count: usize) -> usize {
match FlagsSize::from_count(count) {
FlagsSize::Size0 => 0,
FlagsSize::Size1 | FlagsSize::Size2 => 1,
FlagsSize::Size4Plus(n) => n,
}
}
fn next<'a>(src: &mut std::slice::Iter<'a, ValRaw>) -> &'a ValRaw { fn next<'a>(src: &mut std::slice::Iter<'a, ValRaw>) -> &'a ValRaw {
src.next().unwrap() src.next().unwrap()
} }