Refactor and optimize the flat type calculations (#4708)

* Optimize flat type representation calculations

Previously, the flat type representation was calculated recursively over
an entire type tree every time a type was visited. Additionally, the flat
representation was always built in full, only to be thrown away at the
end if it turned out to be too large. The chief problem was that the
recursion depth depended on the type structure in the component model,
which caused fuzzing to report stack overflows.

This commit overhauls the representation of flat types in Wasmtime by
caching the representation for each type in the compile-time
`ComponentTypesBuilder` structure. This avoids recalculating the flat
representation each time it is queried and additionally makes it possible
to short-circuit earlier and avoid building overly large vectors.

* Remove duplicate flat count calculation in wasmtime

Roughly share the infrastructure in the `wasmtime-environ` crate, namely
the non-recursive and memoizing nature of the calculation.

* Fix component fuzz build

* Fix example compile
This commit is contained in:
Alex Crichton
2022-08-16 13:31:47 -05:00
committed by GitHub
parent 3c1490dd59
commit bc8e36a6af
14 changed files with 561 additions and 279 deletions

View File

@@ -16,7 +16,7 @@
//! can be somewhat arbitrary, an intentional decision.
use crate::component::{
CanonicalAbiInfo, ComponentTypes, InterfaceType, StringEncoding, TypeEnumIndex,
CanonicalAbiInfo, ComponentTypesBuilder, InterfaceType, StringEncoding, TypeEnumIndex,
TypeExpectedIndex, TypeFlagsIndex, TypeInterfaceIndex, TypeOptionIndex, TypeRecordIndex,
TypeTupleIndex, TypeUnionIndex, TypeVariantIndex, VariantInfo, FLAG_MAY_ENTER, FLAG_MAY_LEAVE,
MAX_FLAT_PARAMS, MAX_FLAT_RESULTS,
@@ -36,7 +36,7 @@ const MAX_STRING_BYTE_LENGTH: u32 = 1 << 31;
const UTF16_TAG: u32 = 1 << 31;
struct Compiler<'a, 'b> {
types: &'a ComponentTypes,
types: &'a ComponentTypesBuilder,
module: &'b mut Module<'a>,
result: FunctionId,
@@ -279,14 +279,16 @@ impl Compiler<'_, '_> {
// TODO: handle subtyping
assert_eq!(src_tys.len(), dst_tys.len());
let src_flat = self
.types
.flatten_types(lower_opts, src_tys.iter().copied());
let dst_flat = self.types.flatten_types(lift_opts, dst_tys.iter().copied());
let src_flat =
self.types
.flatten_types(lower_opts, MAX_FLAT_PARAMS, src_tys.iter().copied());
let dst_flat =
self.types
.flatten_types(lift_opts, MAX_FLAT_PARAMS, dst_tys.iter().copied());
let src = if src_flat.len() <= MAX_FLAT_PARAMS {
let src = if let Some(flat) = &src_flat {
Source::Stack(Stack {
locals: &param_locals[..src_flat.len()],
locals: &param_locals[..flat.len()],
opts: lower_opts,
})
} else {
@@ -303,8 +305,8 @@ impl Compiler<'_, '_> {
Source::Memory(self.memory_operand(lower_opts, TempLocal::new(addr, ty), align))
};
let dst = if dst_flat.len() <= MAX_FLAT_PARAMS {
Destination::Stack(&dst_flat, lift_opts)
let dst = if let Some(flat) = &dst_flat {
Destination::Stack(flat, lift_opts)
} else {
// If there are too many parameters then space is allocated in the
// destination module for the parameters via its `realloc` function.
@@ -348,10 +350,14 @@ impl Compiler<'_, '_> {
let lift_opts = &adapter.lift.options;
let lower_opts = &adapter.lower.options;
let src_flat = self.types.flatten_types(lift_opts, [src_ty]);
let dst_flat = self.types.flatten_types(lower_opts, [dst_ty]);
let src_flat = self
.types
.flatten_types(lift_opts, MAX_FLAT_RESULTS, [src_ty]);
let dst_flat = self
.types
.flatten_types(lower_opts, MAX_FLAT_RESULTS, [dst_ty]);
let src = if src_flat.len() <= MAX_FLAT_RESULTS {
let src = if src_flat.is_some() {
Source::Stack(Stack {
locals: result_locals,
opts: lift_opts,
@@ -368,8 +374,8 @@ impl Compiler<'_, '_> {
Source::Memory(self.memory_operand(lift_opts, TempLocal::new(addr, ty), align))
};
let dst = if dst_flat.len() <= MAX_FLAT_RESULTS {
Destination::Stack(&dst_flat, lower_opts)
let dst = if let Some(flat) = &dst_flat {
Destination::Stack(flat, lower_opts)
} else {
// This is slightly different than `translate_params` where the
// return pointer was provided by the caller of this function
@@ -1937,6 +1943,7 @@ impl Compiler<'_, '_> {
FlagsSize::Size4Plus(n) => {
let srcs = src.record_field_srcs(self.types, (0..n).map(|_| InterfaceType::U32));
let dsts = dst.record_field_dsts(self.types, (0..n).map(|_| InterfaceType::U32));
let n = usize::from(n);
for (i, (src, dst)) in srcs.zip(dsts).enumerate() {
let mask = if i == n - 1 && (cnt % 32 != 0) {
(1 << (cnt % 32)) - 1
@@ -2775,7 +2782,7 @@ impl<'a> Source<'a> {
/// offset for each memory-based type.
fn record_field_srcs<'b>(
&'b self,
types: &'b ComponentTypes,
types: &'b ComponentTypesBuilder,
fields: impl IntoIterator<Item = InterfaceType> + 'b,
) -> impl Iterator<Item = Source<'a>> + 'b
where
@@ -2788,7 +2795,7 @@ impl<'a> Source<'a> {
Source::Memory(mem)
}
Source::Stack(stack) => {
let cnt = types.flatten_types(stack.opts, [ty]).len() as u32;
let cnt = types.flat_types(&ty).unwrap().len() as u32;
offset += cnt;
Source::Stack(stack.slice((offset - cnt) as usize..offset as usize))
}
@@ -2798,13 +2805,13 @@ impl<'a> Source<'a> {
/// Returns the payload source for the given variant `case`.
fn payload_src(
&self,
types: &ComponentTypes,
types: &ComponentTypesBuilder,
info: &VariantInfo,
case: &InterfaceType,
) -> Source<'a> {
match self {
Source::Stack(s) => {
let flat_len = types.flatten_types(s.opts, [*case]).len();
let flat_len = types.flat_types(case).unwrap().len();
Source::Stack(s.slice(1..s.locals.len()).slice(0..flat_len))
}
Source::Memory(mem) => {
@@ -2830,7 +2837,7 @@ impl<'a> Destination<'a> {
/// Same as `Source::record_field_srcs` but for destinations.
fn record_field_dsts<'b>(
&'b self,
types: &'b ComponentTypes,
types: &'b ComponentTypesBuilder,
fields: impl IntoIterator<Item = InterfaceType> + 'b,
) -> impl Iterator<Item = Destination> + 'b
where
@@ -2843,7 +2850,7 @@ impl<'a> Destination<'a> {
Destination::Memory(mem)
}
Destination::Stack(s, opts) => {
let cnt = types.flatten_types(opts, [ty]).len() as u32;
let cnt = types.flat_types(&ty).unwrap().len() as u32;
offset += cnt;
Destination::Stack(&s[(offset - cnt) as usize..offset as usize], opts)
}
@@ -2853,13 +2860,13 @@ impl<'a> Destination<'a> {
/// Returns the payload destination for the given variant `case`.
fn payload_dst(
&self,
types: &ComponentTypes,
types: &ComponentTypesBuilder,
info: &VariantInfo,
case: &InterfaceType,
) -> Destination {
match self {
Destination::Stack(s, opts) => {
let flat_len = types.flatten_types(opts, [*case]).len();
let flat_len = types.flat_types(case).unwrap().len();
Destination::Stack(&s[1..][..flat_len], opts)
}
Destination::Memory(mem) => {
@@ -2883,7 +2890,7 @@ impl<'a> Destination<'a> {
fn next_field_offset<'a>(
offset: &mut u32,
types: &ComponentTypes,
types: &ComponentTypesBuilder,
field: &InterfaceType,
mem: &Memory<'a>,
) -> Memory<'a> {
@@ -2930,7 +2937,7 @@ struct VariantCase<'a> {
dst_ty: &'a InterfaceType,
}
fn variant_info<I>(types: &ComponentTypes, cases: I) -> VariantInfo
fn variant_info<I>(types: &ComponentTypesBuilder, cases: I) -> VariantInfo
where
I: IntoIterator<Item = InterfaceType>,
I::IntoIter: ExactSizeIterator,