Refactor and optimize the flat type calculations (#4708)

* Optimize flat type representation calculations

Previously calculating the flat type representation would be done
recursively for an entire type tree every time it was visited.
Additionally the flat type representation was entirely built only to be
thrown away if it was too large at the end. This chiefly presented a
source of recursion based on the type structure in the component model
which fuzzing does not like as it reports stack overflows.

This commit overhauls the representation of flat types in Wasmtime by
caching the representation for each type in the compile-time
`ComponentTypesBuilder` structure. This avoids recalculating each time
the flat representation is queried and additionally allows opportunity
to have more short-circuiting to avoid building overly-large vectors.

* Remove duplicate flat count calculation in wasmtime

Roughly share the infrastructure in the `wasmtime-environ` crate, namely
the non-recursive and memoizing nature of the calculation.

* Fix component fuzz build

* Fix example compile
This commit is contained in:
Alex Crichton
2022-08-16 13:31:47 -05:00
committed by GitHub
parent 3c1490dd59
commit bc8e36a6af
14 changed files with 561 additions and 279 deletions

View File

@@ -302,15 +302,15 @@ impl Func {
result = Type::from(&ty.result, &data.types);
}
let param_count = params.iter().map(|ty| ty.flatten_count()).sum::<usize>();
let result_count = result.flatten_count();
let param_abi = CanonicalAbiInfo::record(params.iter().map(|t| t.canonical_abi()));
let result_count = result.canonical_abi().flat_count(MAX_FLAT_RESULTS);
self.call_raw(
store,
args,
|store, options, args, dst: &mut MaybeUninit<[ValRaw; MAX_FLAT_PARAMS]>| {
if param_count > MAX_FLAT_PARAMS {
self.store_args(store, &options, &params, args, dst)
if param_abi.flat_count(MAX_FLAT_PARAMS).is_none() {
self.store_args(store, &options, &param_abi, &params, args, dst)
} else {
dst.write([ValRaw::u64(0); MAX_FLAT_PARAMS]);
@@ -324,7 +324,7 @@ impl Func {
}
},
|store, options, src: &[ValRaw; MAX_FLAT_RESULTS]| {
if result_count > MAX_FLAT_RESULTS {
if result_count.is_none() {
Self::load_result(&Memory::new(store, &options), &result, &mut src.iter())
} else {
Val::lift(&result, store, &options, &mut src.iter())
@@ -554,12 +554,11 @@ impl Func {
&self,
store: &mut StoreContextMut<'_, T>,
options: &Options,
abi: &CanonicalAbiInfo,
params: &[Type],
args: &[Val],
dst: &mut MaybeUninit<[ValRaw; MAX_FLAT_PARAMS]>,
) -> Result<()> {
let abi = CanonicalAbiInfo::record(params.iter().map(|t| t.canonical_abi()));
let mut memory = MemoryMut::new(store.as_context_mut(), options);
let size = usize::try_from(abi.size32).unwrap();
let ptr = memory.realloc(0, 0, abi.align32, size)?;

View File

@@ -400,12 +400,19 @@ where
bail!("cannot leave component instance");
}
let param_count = params.iter().map(|ty| ty.flatten_count()).sum::<usize>();
let args;
let ret_index;
if param_count <= MAX_FLAT_PARAMS {
let param_abi = CanonicalAbiInfo::record(params.iter().map(|t| t.canonical_abi()));
let param_count = param_abi.flat_count.and_then(|i| {
let i = usize::from(i);
if i > MAX_FLAT_PARAMS {
None
} else {
Some(i)
}
});
if let Some(param_count) = param_count {
let iter = &mut storage.iter();
args = params
.iter()
@@ -413,8 +420,6 @@ where
.collect::<Result<Box<[_]>>>()?;
ret_index = param_count;
} else {
let param_abi = CanonicalAbiInfo::record(params.iter().map(|t| t.canonical_abi()));
let memory = Memory::new(cx.0, &options);
let mut offset = validate_inbounds_dynamic(&param_abi, memory.as_slice(), &storage[0])?;
args = params
@@ -436,8 +441,8 @@ where
flags.set_may_leave(false);
result.check(&ret)?;
let result_count = result.flatten_count();
if result_count <= MAX_FLAT_RESULTS {
let result_count = result.canonical_abi().flat_count(MAX_FLAT_RESULTS);
if result_count.is_some() {
let dst = mem::transmute::<&mut [ValRaw], &mut [MaybeUninit<ValRaw>]>(storage);
ret.lower(&mut cx, &options, &mut dst.iter_mut())?;
} else {

View File

@@ -78,6 +78,10 @@ impl Record {
ty: Type::from(&field.ty, &self.0.types),
})
}
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
&self.0.types[self.0.index].abi
}
}
/// A `tuple` interface type
@@ -97,6 +101,10 @@ impl Tuple {
.iter()
.map(|ty| Type::from(ty, &self.0.types))
}
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
&self.0.types[self.0.index].abi
}
}
/// A case declaration belonging to a `variant`
@@ -128,6 +136,10 @@ impl Variant {
pub(crate) fn variant_info(&self) -> &VariantInfo {
&self.0.types[self.0.index].info
}
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
&self.0.types[self.0.index].abi
}
}
/// An `enum` interface type
@@ -151,6 +163,10 @@ impl Enum {
pub(crate) fn variant_info(&self) -> &VariantInfo {
&self.0.types[self.0.index].info
}
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
&self.0.types[self.0.index].abi
}
}
/// A `union` interface type
@@ -174,6 +190,10 @@ impl Union {
pub(crate) fn variant_info(&self) -> &VariantInfo {
&self.0.types[self.0.index].info
}
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
&self.0.types[self.0.index].abi
}
}
/// An `option` interface type
@@ -194,6 +214,10 @@ impl Option {
pub(crate) fn variant_info(&self) -> &VariantInfo {
&self.0.types[self.0.index].info
}
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
&self.0.types[self.0.index].abi
}
}
/// An `expected` interface type
@@ -219,6 +243,10 @@ impl Expected {
pub(crate) fn variant_info(&self) -> &VariantInfo {
&self.0.types[self.0.index].info
}
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
&self.0.types[self.0.index].abi
}
}
/// A `flags` interface type
@@ -238,6 +266,10 @@ impl Flags {
.iter()
.map(|name| name.deref())
}
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
&self.0.types[self.0.index].abi
}
}
/// Represents a component model interface type
@@ -483,60 +515,6 @@ impl Type {
}
}
/// Return the number of stack slots needed to store values of this type in lowered form.
pub(crate) fn flatten_count(&self) -> usize {
match self {
Type::Unit => 0,
Type::Bool
| Type::S8
| Type::U8
| Type::S16
| Type::U16
| Type::S32
| Type::U32
| Type::S64
| Type::U64
| Type::Float32
| Type::Float64
| Type::Char
| Type::Enum(_) => 1,
Type::String | Type::List(_) => 2,
Type::Record(handle) => handle.fields().map(|field| field.ty.flatten_count()).sum(),
Type::Tuple(handle) => handle.types().map(|ty| ty.flatten_count()).sum(),
Type::Variant(handle) => {
1 + handle
.cases()
.map(|case| case.ty.flatten_count())
.max()
.unwrap_or(0)
}
Type::Union(handle) => {
1 + handle
.types()
.map(|ty| ty.flatten_count())
.max()
.unwrap_or(0)
}
Type::Option(handle) => 1 + handle.ty().flatten_count(),
Type::Expected(handle) => {
1 + handle
.ok()
.flatten_count()
.max(handle.err().flatten_count())
}
Type::Flags(handle) => values::u32_count_for_flag_count(handle.names().len()),
}
}
fn desc(&self) -> &'static str {
match self {
Type::Unit => "unit",
@@ -574,14 +552,14 @@ impl Type {
Type::S32 | Type::U32 | Type::Char | Type::Float32 => &CanonicalAbiInfo::SCALAR4,
Type::S64 | Type::U64 | Type::Float64 => &CanonicalAbiInfo::SCALAR8,
Type::String | Type::List(_) => &CanonicalAbiInfo::POINTER_PAIR,
Type::Record(handle) => &handle.0.types[handle.0.index].abi,
Type::Tuple(handle) => &handle.0.types[handle.0.index].abi,
Type::Variant(handle) => &handle.0.types[handle.0.index].abi,
Type::Enum(handle) => &handle.0.types[handle.0.index].abi,
Type::Union(handle) => &handle.0.types[handle.0.index].abi,
Type::Option(handle) => &handle.0.types[handle.0.index].abi,
Type::Expected(handle) => &handle.0.types[handle.0.index].abi,
Type::Flags(handle) => &handle.0.types[handle.0.index].abi,
Type::Record(handle) => handle.canonical_abi(),
Type::Tuple(handle) => handle.canonical_abi(),
Type::Variant(handle) => handle.canonical_abi(),
Type::Enum(handle) => handle.canonical_abi(),
Type::Union(handle) => handle.canonical_abi(),
Type::Option(handle) => handle.canonical_abi(),
Type::Expected(handle) => handle.canonical_abi(),
Type::Flags(handle) => handle.canonical_abi(),
}
}
}

View File

@@ -440,7 +440,8 @@ impl Flags {
.map(|(index, name)| (name, index))
.collect::<HashMap<_, _>>();
let mut values = vec![0_u32; u32_count_for_flag_count(ty.names().len())];
let count = usize::from(ty.canonical_abi().flat_count.unwrap());
let mut values = vec![0_u32; count];
for name in names {
let index = map
@@ -611,7 +612,7 @@ impl Val {
}),
Type::Variant(handle) => {
let (discriminant, value) = lift_variant(
ty.flatten_count(),
handle.canonical_abi().flat_count(usize::MAX).unwrap(),
handle.cases().map(|case| case.ty),
store,
options,
@@ -626,7 +627,7 @@ impl Val {
}
Type::Enum(handle) => {
let (discriminant, _) = lift_variant(
ty.flatten_count(),
handle.canonical_abi().flat_count(usize::MAX).unwrap(),
handle.names().map(|_| Type::Unit),
store,
options,
@@ -639,8 +640,13 @@ impl Val {
})
}
Type::Union(handle) => {
let (discriminant, value) =
lift_variant(ty.flatten_count(), handle.types(), store, options, src)?;
let (discriminant, value) = lift_variant(
handle.canonical_abi().flat_count(usize::MAX).unwrap(),
handle.types(),
store,
options,
src,
)?;
Val::Union(Union {
ty: handle.clone(),
@@ -650,7 +656,7 @@ impl Val {
}
Type::Option(handle) => {
let (discriminant, value) = lift_variant(
ty.flatten_count(),
handle.canonical_abi().flat_count(usize::MAX).unwrap(),
[Type::Unit, handle.ty()].into_iter(),
store,
options,
@@ -665,7 +671,7 @@ impl Val {
}
Type::Expected(handle) => {
let (discriminant, value) = lift_variant(
ty.flatten_count(),
handle.canonical_abi().flat_count(usize::MAX).unwrap(),
[handle.ok(), handle.err()].into_iter(),
store,
options,
@@ -680,8 +686,9 @@ impl Val {
}
Type::Flags(handle) => {
let count = u32::try_from(handle.names().len()).unwrap();
let u32_count = handle.canonical_abi().flat_count(usize::MAX).unwrap();
let value = iter::repeat_with(|| u32::lift(store, options, next(src)))
.take(u32_count_for_flag_count(count.try_into()?))
.take(u32_count)
.collect::<Result<_>>()?;
Val::Flags(Flags {
@@ -797,7 +804,7 @@ impl Val {
FlagsSize::Size1 => iter::once(u8::load(mem, bytes)? as u32).collect(),
FlagsSize::Size2 => iter::once(u16::load(mem, bytes)? as u32).collect(),
FlagsSize::Size4Plus(n) => (0..n)
.map(|index| u32::load(mem, &bytes[index * 4..][..4]))
.map(|index| u32::load(mem, &bytes[usize::from(index) * 4..][..4]))
.collect::<Result<_>>()?,
},
}),
@@ -868,7 +875,9 @@ impl Val {
}) => {
next_mut(dst).write(ValRaw::u32(*discriminant));
value.lower(store, options, dst)?;
for _ in (1 + value.ty().flatten_count())..self.ty().flatten_count() {
let value_flat = value.ty().canonical_abi().flat_count(usize::MAX).unwrap();
let variant_flat = self.ty().canonical_abi().flat_count(usize::MAX).unwrap();
for _ in (1 + value_flat)..variant_flat {
next_mut(dst).write(ValRaw::u32(0));
}
}
@@ -1070,7 +1079,8 @@ fn lift_variant<'a>(
.nth(discriminant as usize)
.ok_or_else(|| anyhow!("discriminant {} out of range [0..{})", discriminant, len))?;
let value = Val::lift(&ty, store, options, src)?;
for _ in (1 + ty.flatten_count())..flatten_count {
let value_flat = ty.canonical_abi().flat_count(usize::MAX).unwrap();
for _ in (1 + value_flat)..flatten_count {
next(src);
}
Ok((discriminant, value))
@@ -1098,17 +1108,6 @@ fn lower_list<T>(
Ok((ptr, items.len()))
}
/// Calculate the size of a u32 array needed to represent the specified number of bit flags.
///
/// Note that this will always return at least 1, even if the `count` parameter is zero.
pub(crate) fn u32_count_for_flag_count(count: usize) -> usize {
match FlagsSize::from_count(count) {
FlagsSize::Size0 => 0,
FlagsSize::Size1 | FlagsSize::Size2 => 1,
FlagsSize::Size4Plus(n) => n,
}
}
fn next<'a>(src: &mut std::slice::Iter<'a, ValRaw>) -> &'a ValRaw {
src.next().unwrap()
}