Refactor and optimize the flat type calculations (#4708)
* Optimize flat type representation calculations Previously calculating the flat type representation would be done recursively for an entire type tree every time it was visited. Additionally the flat type representation was entirely built only to be thrown away if it was too large at the end. This chiefly presented a source of recursion based on the type structure in the component model which fuzzing does not like as it reports stack overflows. This commit overhauls the representation of flat types in Wasmtime by caching the representation for each type in the compile-time `ComponentTypesBuilder` structure. This avoids recalculating each time the flat representation is queried and additionally allows opportunity to have more short-circuiting to avoid building overly-large vectors. * Remove duplicate flat count calculation in wasmtime Roughly share the infrastructure in the `wasmtime-environ` crate, namely the non-recursive and memoizing nature of the calculation. * Fix component fuzz build * Fix example compile
This commit is contained in:
@@ -302,15 +302,15 @@ impl Func {
|
||||
result = Type::from(&ty.result, &data.types);
|
||||
}
|
||||
|
||||
let param_count = params.iter().map(|ty| ty.flatten_count()).sum::<usize>();
|
||||
let result_count = result.flatten_count();
|
||||
let param_abi = CanonicalAbiInfo::record(params.iter().map(|t| t.canonical_abi()));
|
||||
let result_count = result.canonical_abi().flat_count(MAX_FLAT_RESULTS);
|
||||
|
||||
self.call_raw(
|
||||
store,
|
||||
args,
|
||||
|store, options, args, dst: &mut MaybeUninit<[ValRaw; MAX_FLAT_PARAMS]>| {
|
||||
if param_count > MAX_FLAT_PARAMS {
|
||||
self.store_args(store, &options, ¶ms, args, dst)
|
||||
if param_abi.flat_count(MAX_FLAT_PARAMS).is_none() {
|
||||
self.store_args(store, &options, ¶m_abi, ¶ms, args, dst)
|
||||
} else {
|
||||
dst.write([ValRaw::u64(0); MAX_FLAT_PARAMS]);
|
||||
|
||||
@@ -324,7 +324,7 @@ impl Func {
|
||||
}
|
||||
},
|
||||
|store, options, src: &[ValRaw; MAX_FLAT_RESULTS]| {
|
||||
if result_count > MAX_FLAT_RESULTS {
|
||||
if result_count.is_none() {
|
||||
Self::load_result(&Memory::new(store, &options), &result, &mut src.iter())
|
||||
} else {
|
||||
Val::lift(&result, store, &options, &mut src.iter())
|
||||
@@ -554,12 +554,11 @@ impl Func {
|
||||
&self,
|
||||
store: &mut StoreContextMut<'_, T>,
|
||||
options: &Options,
|
||||
abi: &CanonicalAbiInfo,
|
||||
params: &[Type],
|
||||
args: &[Val],
|
||||
dst: &mut MaybeUninit<[ValRaw; MAX_FLAT_PARAMS]>,
|
||||
) -> Result<()> {
|
||||
let abi = CanonicalAbiInfo::record(params.iter().map(|t| t.canonical_abi()));
|
||||
|
||||
let mut memory = MemoryMut::new(store.as_context_mut(), options);
|
||||
let size = usize::try_from(abi.size32).unwrap();
|
||||
let ptr = memory.realloc(0, 0, abi.align32, size)?;
|
||||
|
||||
@@ -400,12 +400,19 @@ where
|
||||
bail!("cannot leave component instance");
|
||||
}
|
||||
|
||||
let param_count = params.iter().map(|ty| ty.flatten_count()).sum::<usize>();
|
||||
|
||||
let args;
|
||||
let ret_index;
|
||||
|
||||
if param_count <= MAX_FLAT_PARAMS {
|
||||
let param_abi = CanonicalAbiInfo::record(params.iter().map(|t| t.canonical_abi()));
|
||||
let param_count = param_abi.flat_count.and_then(|i| {
|
||||
let i = usize::from(i);
|
||||
if i > MAX_FLAT_PARAMS {
|
||||
None
|
||||
} else {
|
||||
Some(i)
|
||||
}
|
||||
});
|
||||
if let Some(param_count) = param_count {
|
||||
let iter = &mut storage.iter();
|
||||
args = params
|
||||
.iter()
|
||||
@@ -413,8 +420,6 @@ where
|
||||
.collect::<Result<Box<[_]>>>()?;
|
||||
ret_index = param_count;
|
||||
} else {
|
||||
let param_abi = CanonicalAbiInfo::record(params.iter().map(|t| t.canonical_abi()));
|
||||
|
||||
let memory = Memory::new(cx.0, &options);
|
||||
let mut offset = validate_inbounds_dynamic(¶m_abi, memory.as_slice(), &storage[0])?;
|
||||
args = params
|
||||
@@ -436,8 +441,8 @@ where
|
||||
flags.set_may_leave(false);
|
||||
result.check(&ret)?;
|
||||
|
||||
let result_count = result.flatten_count();
|
||||
if result_count <= MAX_FLAT_RESULTS {
|
||||
let result_count = result.canonical_abi().flat_count(MAX_FLAT_RESULTS);
|
||||
if result_count.is_some() {
|
||||
let dst = mem::transmute::<&mut [ValRaw], &mut [MaybeUninit<ValRaw>]>(storage);
|
||||
ret.lower(&mut cx, &options, &mut dst.iter_mut())?;
|
||||
} else {
|
||||
|
||||
@@ -78,6 +78,10 @@ impl Record {
|
||||
ty: Type::from(&field.ty, &self.0.types),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
|
||||
&self.0.types[self.0.index].abi
|
||||
}
|
||||
}
|
||||
|
||||
/// A `tuple` interface type
|
||||
@@ -97,6 +101,10 @@ impl Tuple {
|
||||
.iter()
|
||||
.map(|ty| Type::from(ty, &self.0.types))
|
||||
}
|
||||
|
||||
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
|
||||
&self.0.types[self.0.index].abi
|
||||
}
|
||||
}
|
||||
|
||||
/// A case declaration belonging to a `variant`
|
||||
@@ -128,6 +136,10 @@ impl Variant {
|
||||
pub(crate) fn variant_info(&self) -> &VariantInfo {
|
||||
&self.0.types[self.0.index].info
|
||||
}
|
||||
|
||||
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
|
||||
&self.0.types[self.0.index].abi
|
||||
}
|
||||
}
|
||||
|
||||
/// An `enum` interface type
|
||||
@@ -151,6 +163,10 @@ impl Enum {
|
||||
pub(crate) fn variant_info(&self) -> &VariantInfo {
|
||||
&self.0.types[self.0.index].info
|
||||
}
|
||||
|
||||
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
|
||||
&self.0.types[self.0.index].abi
|
||||
}
|
||||
}
|
||||
|
||||
/// A `union` interface type
|
||||
@@ -174,6 +190,10 @@ impl Union {
|
||||
pub(crate) fn variant_info(&self) -> &VariantInfo {
|
||||
&self.0.types[self.0.index].info
|
||||
}
|
||||
|
||||
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
|
||||
&self.0.types[self.0.index].abi
|
||||
}
|
||||
}
|
||||
|
||||
/// An `option` interface type
|
||||
@@ -194,6 +214,10 @@ impl Option {
|
||||
pub(crate) fn variant_info(&self) -> &VariantInfo {
|
||||
&self.0.types[self.0.index].info
|
||||
}
|
||||
|
||||
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
|
||||
&self.0.types[self.0.index].abi
|
||||
}
|
||||
}
|
||||
|
||||
/// An `expected` interface type
|
||||
@@ -219,6 +243,10 @@ impl Expected {
|
||||
pub(crate) fn variant_info(&self) -> &VariantInfo {
|
||||
&self.0.types[self.0.index].info
|
||||
}
|
||||
|
||||
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
|
||||
&self.0.types[self.0.index].abi
|
||||
}
|
||||
}
|
||||
|
||||
/// A `flags` interface type
|
||||
@@ -238,6 +266,10 @@ impl Flags {
|
||||
.iter()
|
||||
.map(|name| name.deref())
|
||||
}
|
||||
|
||||
pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo {
|
||||
&self.0.types[self.0.index].abi
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a component model interface type
|
||||
@@ -483,60 +515,6 @@ impl Type {
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the number of stack slots needed to store values of this type in lowered form.
|
||||
pub(crate) fn flatten_count(&self) -> usize {
|
||||
match self {
|
||||
Type::Unit => 0,
|
||||
|
||||
Type::Bool
|
||||
| Type::S8
|
||||
| Type::U8
|
||||
| Type::S16
|
||||
| Type::U16
|
||||
| Type::S32
|
||||
| Type::U32
|
||||
| Type::S64
|
||||
| Type::U64
|
||||
| Type::Float32
|
||||
| Type::Float64
|
||||
| Type::Char
|
||||
| Type::Enum(_) => 1,
|
||||
|
||||
Type::String | Type::List(_) => 2,
|
||||
|
||||
Type::Record(handle) => handle.fields().map(|field| field.ty.flatten_count()).sum(),
|
||||
|
||||
Type::Tuple(handle) => handle.types().map(|ty| ty.flatten_count()).sum(),
|
||||
|
||||
Type::Variant(handle) => {
|
||||
1 + handle
|
||||
.cases()
|
||||
.map(|case| case.ty.flatten_count())
|
||||
.max()
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
Type::Union(handle) => {
|
||||
1 + handle
|
||||
.types()
|
||||
.map(|ty| ty.flatten_count())
|
||||
.max()
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
Type::Option(handle) => 1 + handle.ty().flatten_count(),
|
||||
|
||||
Type::Expected(handle) => {
|
||||
1 + handle
|
||||
.ok()
|
||||
.flatten_count()
|
||||
.max(handle.err().flatten_count())
|
||||
}
|
||||
|
||||
Type::Flags(handle) => values::u32_count_for_flag_count(handle.names().len()),
|
||||
}
|
||||
}
|
||||
|
||||
fn desc(&self) -> &'static str {
|
||||
match self {
|
||||
Type::Unit => "unit",
|
||||
@@ -574,14 +552,14 @@ impl Type {
|
||||
Type::S32 | Type::U32 | Type::Char | Type::Float32 => &CanonicalAbiInfo::SCALAR4,
|
||||
Type::S64 | Type::U64 | Type::Float64 => &CanonicalAbiInfo::SCALAR8,
|
||||
Type::String | Type::List(_) => &CanonicalAbiInfo::POINTER_PAIR,
|
||||
Type::Record(handle) => &handle.0.types[handle.0.index].abi,
|
||||
Type::Tuple(handle) => &handle.0.types[handle.0.index].abi,
|
||||
Type::Variant(handle) => &handle.0.types[handle.0.index].abi,
|
||||
Type::Enum(handle) => &handle.0.types[handle.0.index].abi,
|
||||
Type::Union(handle) => &handle.0.types[handle.0.index].abi,
|
||||
Type::Option(handle) => &handle.0.types[handle.0.index].abi,
|
||||
Type::Expected(handle) => &handle.0.types[handle.0.index].abi,
|
||||
Type::Flags(handle) => &handle.0.types[handle.0.index].abi,
|
||||
Type::Record(handle) => handle.canonical_abi(),
|
||||
Type::Tuple(handle) => handle.canonical_abi(),
|
||||
Type::Variant(handle) => handle.canonical_abi(),
|
||||
Type::Enum(handle) => handle.canonical_abi(),
|
||||
Type::Union(handle) => handle.canonical_abi(),
|
||||
Type::Option(handle) => handle.canonical_abi(),
|
||||
Type::Expected(handle) => handle.canonical_abi(),
|
||||
Type::Flags(handle) => handle.canonical_abi(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -440,7 +440,8 @@ impl Flags {
|
||||
.map(|(index, name)| (name, index))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let mut values = vec![0_u32; u32_count_for_flag_count(ty.names().len())];
|
||||
let count = usize::from(ty.canonical_abi().flat_count.unwrap());
|
||||
let mut values = vec![0_u32; count];
|
||||
|
||||
for name in names {
|
||||
let index = map
|
||||
@@ -611,7 +612,7 @@ impl Val {
|
||||
}),
|
||||
Type::Variant(handle) => {
|
||||
let (discriminant, value) = lift_variant(
|
||||
ty.flatten_count(),
|
||||
handle.canonical_abi().flat_count(usize::MAX).unwrap(),
|
||||
handle.cases().map(|case| case.ty),
|
||||
store,
|
||||
options,
|
||||
@@ -626,7 +627,7 @@ impl Val {
|
||||
}
|
||||
Type::Enum(handle) => {
|
||||
let (discriminant, _) = lift_variant(
|
||||
ty.flatten_count(),
|
||||
handle.canonical_abi().flat_count(usize::MAX).unwrap(),
|
||||
handle.names().map(|_| Type::Unit),
|
||||
store,
|
||||
options,
|
||||
@@ -639,8 +640,13 @@ impl Val {
|
||||
})
|
||||
}
|
||||
Type::Union(handle) => {
|
||||
let (discriminant, value) =
|
||||
lift_variant(ty.flatten_count(), handle.types(), store, options, src)?;
|
||||
let (discriminant, value) = lift_variant(
|
||||
handle.canonical_abi().flat_count(usize::MAX).unwrap(),
|
||||
handle.types(),
|
||||
store,
|
||||
options,
|
||||
src,
|
||||
)?;
|
||||
|
||||
Val::Union(Union {
|
||||
ty: handle.clone(),
|
||||
@@ -650,7 +656,7 @@ impl Val {
|
||||
}
|
||||
Type::Option(handle) => {
|
||||
let (discriminant, value) = lift_variant(
|
||||
ty.flatten_count(),
|
||||
handle.canonical_abi().flat_count(usize::MAX).unwrap(),
|
||||
[Type::Unit, handle.ty()].into_iter(),
|
||||
store,
|
||||
options,
|
||||
@@ -665,7 +671,7 @@ impl Val {
|
||||
}
|
||||
Type::Expected(handle) => {
|
||||
let (discriminant, value) = lift_variant(
|
||||
ty.flatten_count(),
|
||||
handle.canonical_abi().flat_count(usize::MAX).unwrap(),
|
||||
[handle.ok(), handle.err()].into_iter(),
|
||||
store,
|
||||
options,
|
||||
@@ -680,8 +686,9 @@ impl Val {
|
||||
}
|
||||
Type::Flags(handle) => {
|
||||
let count = u32::try_from(handle.names().len()).unwrap();
|
||||
let u32_count = handle.canonical_abi().flat_count(usize::MAX).unwrap();
|
||||
let value = iter::repeat_with(|| u32::lift(store, options, next(src)))
|
||||
.take(u32_count_for_flag_count(count.try_into()?))
|
||||
.take(u32_count)
|
||||
.collect::<Result<_>>()?;
|
||||
|
||||
Val::Flags(Flags {
|
||||
@@ -797,7 +804,7 @@ impl Val {
|
||||
FlagsSize::Size1 => iter::once(u8::load(mem, bytes)? as u32).collect(),
|
||||
FlagsSize::Size2 => iter::once(u16::load(mem, bytes)? as u32).collect(),
|
||||
FlagsSize::Size4Plus(n) => (0..n)
|
||||
.map(|index| u32::load(mem, &bytes[index * 4..][..4]))
|
||||
.map(|index| u32::load(mem, &bytes[usize::from(index) * 4..][..4]))
|
||||
.collect::<Result<_>>()?,
|
||||
},
|
||||
}),
|
||||
@@ -868,7 +875,9 @@ impl Val {
|
||||
}) => {
|
||||
next_mut(dst).write(ValRaw::u32(*discriminant));
|
||||
value.lower(store, options, dst)?;
|
||||
for _ in (1 + value.ty().flatten_count())..self.ty().flatten_count() {
|
||||
let value_flat = value.ty().canonical_abi().flat_count(usize::MAX).unwrap();
|
||||
let variant_flat = self.ty().canonical_abi().flat_count(usize::MAX).unwrap();
|
||||
for _ in (1 + value_flat)..variant_flat {
|
||||
next_mut(dst).write(ValRaw::u32(0));
|
||||
}
|
||||
}
|
||||
@@ -1070,7 +1079,8 @@ fn lift_variant<'a>(
|
||||
.nth(discriminant as usize)
|
||||
.ok_or_else(|| anyhow!("discriminant {} out of range [0..{})", discriminant, len))?;
|
||||
let value = Val::lift(&ty, store, options, src)?;
|
||||
for _ in (1 + ty.flatten_count())..flatten_count {
|
||||
let value_flat = ty.canonical_abi().flat_count(usize::MAX).unwrap();
|
||||
for _ in (1 + value_flat)..flatten_count {
|
||||
next(src);
|
||||
}
|
||||
Ok((discriminant, value))
|
||||
@@ -1098,17 +1108,6 @@ fn lower_list<T>(
|
||||
Ok((ptr, items.len()))
|
||||
}
|
||||
|
||||
/// Calculate the size of a u32 array needed to represent the specified number of bit flags.
|
||||
///
|
||||
/// Note that this will always return at least 1, even if the `count` parameter is zero.
|
||||
pub(crate) fn u32_count_for_flag_count(count: usize) -> usize {
|
||||
match FlagsSize::from_count(count) {
|
||||
FlagsSize::Size0 => 0,
|
||||
FlagsSize::Size1 | FlagsSize::Size2 => 1,
|
||||
FlagsSize::Size4Plus(n) => n,
|
||||
}
|
||||
}
|
||||
|
||||
fn next<'a>(src: &mut std::slice::Iter<'a, ValRaw>) -> &'a ValRaw {
|
||||
src.next().unwrap()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user