diff --git a/crates/component-macro/src/lib.rs b/crates/component-macro/src/lib.rs index 7809caf7dd..afef164778 100644 --- a/crates/component-macro/src/lib.rs +++ b/crates/component-macro/src/lib.rs @@ -1015,7 +1015,7 @@ fn expand_flags(flags: &Flags) -> Result { }); } FlagsSize::Size4Plus(n) => { - count = n; + count = usize::from(n); as_array = TokenStream::new(); bitor = TokenStream::new(); bitor_assign = TokenStream::new(); @@ -1072,7 +1072,7 @@ fn expand_flags(flags: &Flags) -> Result { .map(|i| { let field = format_ident!("__inner{}", i); - let init = if index / 32 == i { + let init = if index / 32 == usize::from(i) { 1_u32 << (index % 32) } else { 0 diff --git a/crates/component-util/src/lib.rs b/crates/component-util/src/lib.rs index 3aa77df304..db1e4b9768 100644 --- a/crates/component-util/src/lib.rs +++ b/crates/component-util/src/lib.rs @@ -60,7 +60,7 @@ pub enum FlagsSize { /// Flags can fit in a u16 Size2, /// Flags can fit in a specified number of u32 fields - Size4Plus(usize), + Size4Plus(u8), } impl FlagsSize { @@ -73,7 +73,11 @@ impl FlagsSize { } else if count <= 16 { FlagsSize::Size2 } else { - FlagsSize::Size4Plus(ceiling_divide(count, 32)) + let amt = ceiling_divide(count, 32); + if amt > (u8::MAX as usize) { + panic!("too many flags"); + } + FlagsSize::Size4Plus(amt as u8) } } } diff --git a/crates/environ/examples/factc.rs b/crates/environ/examples/factc.rs index 38300a830c..cc14a4ae6a 100644 --- a/crates/environ/examples/factc.rs +++ b/crates/environ/examples/factc.rs @@ -174,7 +174,6 @@ impl Factc { } types.pop_type_scope(); - let types = types.finish(); let mut fact_module = Module::new(&types, self.debug); for (i, adapter) in adapters.iter().enumerate() { fact_module.adapt(&format!("adapter{i}"), adapter); diff --git a/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs b/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs index 2fd52d9004..2c7fe269ed 100644 --- a/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs +++ b/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs @@ -143,7 +143,6 @@ fn target(module: GenAdapterModule) { types.pop_type_scope(); } - let types = types.finish(); let mut fact_module = Module::new(&types, module.debug); for (i, adapter) in adapters.iter().enumerate() { fact_module.adapt(&format!("adapter{i}"), adapter); diff --git a/crates/environ/src/component/translate/adapt.rs b/crates/environ/src/component/translate/adapt.rs index 1b8c6c6830..e8eefdc1fe 100644 --- a/crates/environ/src/component/translate/adapt.rs +++ b/crates/environ/src/component/translate/adapt.rs @@ -184,10 +184,7 @@ impl<'data> Translator<'_, 'data> { // the module using standard core wasm translation, and then fills out // the dfg metadata for each adapter. for (module_id, adapter_module) in state.adapter_modules.iter() { - let mut module = fact::Module::new( - self.types.component_types(), - self.tunables.debug_adapter_modules, - ); + let mut module = fact::Module::new(self.types, self.tunables.debug_adapter_modules); let mut names = Vec::with_capacity(adapter_module.adapters.len()); for adapter in adapter_module.adapters.iter() { let name = format!("adapter{}", adapter.as_u32()); diff --git a/crates/environ/src/component/types.rs b/crates/environ/src/component/types.rs index 5c70f3691f..0c4e47ee75 100644 --- a/crates/environ/src/component/types.rs +++ b/crates/environ/src/component/types.rs @@ -1,3 +1,4 @@ +use crate::component::{MAX_FLAT_PARAMS, MAX_FLAT_RESULTS}; use crate::{ EntityType, Global, GlobalInit, ModuleTypes, ModuleTypesBuilder, PrimaryMap, SignatureIndex, }; @@ -318,6 +319,11 @@ pub struct ComponentTypesBuilder { component_types: ComponentTypes, module_types: ModuleTypesBuilder, + + // Cache of what the "flat" representation of all types are which is only + // used at compile-time and not used at runtime, hence the location here + // as opposed to `ComponentTypes`. + flat: FlatTypesCache, } #[derive(Default)] @@ -326,6 +332,21 @@ struct TypeScope { component: PrimaryMap, } +macro_rules! intern_and_fill_flat_types { + ($me:ident, $name:ident, $val:ident) => {{ + if let Some(idx) = $me.$name.get(&$val) { + return *idx; + } + let idx = $me.component_types.$name.push($val.clone()); + let mut storage = FlatTypesStorage::new(); + storage.$name($me, &$val); + let idx2 = $me.flat.$name.push(storage); + assert_eq!(idx, idx2); + $me.$name.insert($val, idx); + return idx; + }}; +} + impl ComponentTypesBuilder { /// Finishes this list of component types and returns the finished /// structure. @@ -769,42 +790,42 @@ impl ComponentTypesBuilder { /// Interns a new record type within this type information. pub fn add_record_type(&mut self, ty: TypeRecord) -> TypeRecordIndex { - intern(&mut self.records, &mut self.component_types.records, ty) + intern_and_fill_flat_types!(self, records, ty) } /// Interns a new flags type within this type information. pub fn add_flags_type(&mut self, ty: TypeFlags) -> TypeFlagsIndex { - intern(&mut self.flags, &mut self.component_types.flags, ty) + intern_and_fill_flat_types!(self, flags, ty) } /// Interns a new tuple type within this type information. pub fn add_tuple_type(&mut self, ty: TypeTuple) -> TypeTupleIndex { - intern(&mut self.tuples, &mut self.component_types.tuples, ty) + intern_and_fill_flat_types!(self, tuples, ty) } /// Interns a new variant type within this type information. pub fn add_variant_type(&mut self, ty: TypeVariant) -> TypeVariantIndex { - intern(&mut self.variants, &mut self.component_types.variants, ty) + intern_and_fill_flat_types!(self, variants, ty) } /// Interns a new union type within this type information. pub fn add_union_type(&mut self, ty: TypeUnion) -> TypeUnionIndex { - intern(&mut self.unions, &mut self.component_types.unions, ty) + intern_and_fill_flat_types!(self, unions, ty) } /// Interns a new enum type within this type information. pub fn add_enum_type(&mut self, ty: TypeEnum) -> TypeEnumIndex { - intern(&mut self.enums, &mut self.component_types.enums, ty) + intern_and_fill_flat_types!(self, enums, ty) } /// Interns a new option type within this type information. pub fn add_option_type(&mut self, ty: TypeOption) -> TypeOptionIndex { - intern(&mut self.options, &mut self.component_types.options, ty) + intern_and_fill_flat_types!(self, options, ty) } /// Interns a new expected type within this type information. pub fn add_expected_type(&mut self, ty: TypeExpected) -> TypeExpectedIndex { - intern(&mut self.expecteds, &mut self.component_types.expecteds, ty) + intern_and_fill_flat_types!(self, expecteds, ty) } /// Interns a new expected type within this type information. @@ -815,6 +836,43 @@ impl ComponentTypesBuilder { ty, ) } + + /// Returns the canonical ABI information about the specified type. + pub fn canonical_abi(&self, ty: &InterfaceType) -> &CanonicalAbiInfo { + self.component_types.canonical_abi(ty) + } + + /// Returns the "flat types" for the given interface type used in the + /// canonical ABI. + /// + /// Returns `None` if the type is too large to be represented via flat types + /// in the canonical abi. + pub fn flat_types(&self, ty: &InterfaceType) -> Option> { + match ty { + InterfaceType::Unit => Some(FlatTypes::EMPTY), + InterfaceType::U8 + | InterfaceType::S8 + | InterfaceType::Bool + | InterfaceType::U16 + | InterfaceType::S16 + | InterfaceType::U32 + | InterfaceType::S32 + | InterfaceType::Char => Some(FlatTypes::I32), + InterfaceType::U64 | InterfaceType::S64 => Some(FlatTypes::I64), + InterfaceType::Float32 => Some(FlatTypes::F32), + InterfaceType::Float64 => Some(FlatTypes::F64), + InterfaceType::String | InterfaceType::List(_) => Some(FlatTypes::POINTER_PAIR), + + InterfaceType::Record(i) => self.flat.records[*i].as_flat_types(), + InterfaceType::Variant(i) => self.flat.variants[*i].as_flat_types(), + InterfaceType::Tuple(i) => self.flat.tuples[*i].as_flat_types(), + InterfaceType::Flags(i) => self.flat.flags[*i].as_flat_types(), + InterfaceType::Enum(i) => self.flat.enums[*i].as_flat_types(), + InterfaceType::Union(i) => self.flat.unions[*i].as_flat_types(), + InterfaceType::Option(i) => self.flat.options[*i].as_flat_types(), + InterfaceType::Expected(i) => self.flat.expecteds[*i].as_flat_types(), + } + } } // Forward the indexing impl to the internal `TypeTables` @@ -986,6 +1044,13 @@ pub struct CanonicalAbiInfo { pub size64: u32, /// The byte-alignment of this type in a 64-bit memory. pub align64: u32, + /// The number of types it takes to represents this type in the "flat" + /// representation of the canonical abi where everything is passed as + /// immediate arguments or results. + /// + /// If this is `None` then this type is not representable in the flat ABI + /// because it is too large. + pub flat_count: Option, } impl Default for CanonicalAbiInfo { @@ -995,6 +1060,7 @@ impl Default for CanonicalAbiInfo { align32: 1, size64: 0, align64: 1, + flat_count: Some(0), } } } @@ -1019,6 +1085,7 @@ impl CanonicalAbiInfo { align32: 1, size64: 0, align64: 1, + flat_count: Some(0), }; /// ABI information for one-byte scalars. @@ -1036,6 +1103,7 @@ impl CanonicalAbiInfo { align32: size, size64: size, align64: size, + flat_count: Some(1), } } @@ -1045,6 +1113,7 @@ impl CanonicalAbiInfo { align32: 4, size64: 16, align64: 8, + flat_count: Some(2), }; /// Returns the abi for a record represented by the specified fields. @@ -1058,6 +1127,7 @@ impl CanonicalAbiInfo { ret.align32 = ret.align32.max(field.align32); ret.size64 = align_to(ret.size64, field.align64) + field.size64; ret.align64 = ret.align64.max(field.align64); + ret.flat_count = add_flat(ret.flat_count, field.flat_count); } ret.size32 = align_to(ret.size32, ret.align32); ret.size64 = align_to(ret.size64, ret.align64); @@ -1077,6 +1147,7 @@ impl CanonicalAbiInfo { ret.align32 = max(ret.align32, field.align32); ret.size64 = align_to(ret.size64, field.align64) + field.size64; ret.align64 = max(ret.align64, field.align64); + ret.flat_count = add_flat(ret.flat_count, field.flat_count); i += 1; } ret.size32 = align_to(ret.size32, ret.align32); @@ -1116,17 +1187,18 @@ impl CanonicalAbiInfo { /// Returns ABI information for a structure which contains `count` flags. pub const fn flags(count: usize) -> CanonicalAbiInfo { - let (size, align) = match FlagsSize::from_count(count) { - FlagsSize::Size0 => (0, 1), - FlagsSize::Size1 => (1, 1), - FlagsSize::Size2 => (2, 2), - FlagsSize::Size4Plus(n) => ((n as u32) * 4, 4), + let (size, align, flat_count) = match FlagsSize::from_count(count) { + FlagsSize::Size0 => (0, 1, 0), + FlagsSize::Size1 => (1, 1, 1), + FlagsSize::Size2 => (2, 2, 1), + FlagsSize::Size4Plus(n) => ((n as u32) * 4, 4, n), }; CanonicalAbiInfo { size32: size, align32: align, size64: size, align64: align, + flat_count: Some(flat_count), } } @@ -1144,11 +1216,13 @@ impl CanonicalAbiInfo { let mut max_align32 = discrim_size; let mut max_size64 = 0; let mut max_align64 = discrim_size; + let mut max_case_count = Some(0); for case in cases { max_size32 = max_size32.max(case.size32); max_align32 = max_align32.max(case.align32); max_size64 = max_size64.max(case.size64); max_align64 = max_align64.max(case.align64); + max_case_count = max_flat(max_case_count, case.flat_count); } CanonicalAbiInfo { size32: align_to( @@ -1161,6 +1235,7 @@ impl CanonicalAbiInfo { max_align64, ), align64: max_align64, + flat_count: add_flat(max_case_count, Some(1)), } } @@ -1177,6 +1252,7 @@ impl CanonicalAbiInfo { let mut max_align32 = discrim_size; let mut max_size64 = 0; let mut max_align64 = discrim_size; + let mut max_case_count = Some(0); let mut i = 0; while i < cases.len() { let case = &cases[i]; @@ -1184,6 +1260,7 @@ impl CanonicalAbiInfo { max_align32 = max(max_align32, case.align32); max_size64 = max(max_size64, case.size64); max_align64 = max(max_align64, case.align64); + max_case_count = max_flat(max_case_count, case.flat_count); i += 1; } CanonicalAbiInfo { @@ -1197,6 +1274,18 @@ impl CanonicalAbiInfo { max_align64, ), align64: max_align64, + flat_count: add_flat(max_case_count, Some(1)), + } + } + + /// Returns the flat count of this ABI information so long as the count + /// doesn't exceed the `max` specified. + pub fn flat_count(&self, max: usize) -> Option { + let flat = usize::from(self.flat_count?); + if flat > max { + None + } else { + Some(flat) } } } @@ -1396,3 +1485,284 @@ pub struct TypeExpected { /// Byte information about this variant type. pub info: VariantInfo, } + +const MAX_FLAT_TYPES: usize = if MAX_FLAT_PARAMS > MAX_FLAT_RESULTS { + MAX_FLAT_PARAMS +} else { + MAX_FLAT_RESULTS +}; + +const fn add_flat(a: Option, b: Option) -> Option { + const MAX: u8 = MAX_FLAT_TYPES as u8; + let sum = match (a, b) { + (Some(a), Some(b)) => match a.checked_add(b) { + Some(c) => c, + None => return None, + }, + _ => return None, + }; + if sum > MAX { + None + } else { + Some(sum) + } +} + +const fn max_flat(a: Option, b: Option) -> Option { + match (a, b) { + (Some(a), Some(b)) => { + if a > b { + Some(a) + } else { + Some(b) + } + } + _ => None, + } +} + +/// Flat representation of a type in just core wasm types. +pub struct FlatTypes<'a> { + /// The flat representation of this type in 32-bit memories. + pub memory32: &'a [FlatType], + /// The flat representation of this type in 64-bit memories. + pub memory64: &'a [FlatType], +} + +#[allow(missing_docs)] +impl FlatTypes<'_> { + pub const EMPTY: FlatTypes<'static> = FlatTypes::new(&[]); + pub const I32: FlatTypes<'static> = FlatTypes::new(&[FlatType::I32]); + pub const I64: FlatTypes<'static> = FlatTypes::new(&[FlatType::I64]); + pub const F32: FlatTypes<'static> = FlatTypes::new(&[FlatType::F32]); + pub const F64: FlatTypes<'static> = FlatTypes::new(&[FlatType::F64]); + pub const POINTER_PAIR: FlatTypes<'static> = FlatTypes { + memory32: &[FlatType::I32, FlatType::I32], + memory64: &[FlatType::I64, FlatType::I64], + }; + + const fn new(flat: &[FlatType]) -> FlatTypes<'_> { + FlatTypes { + memory32: flat, + memory64: flat, + } + } + + /// Returns the number of flat types used to represent this type. + /// + /// Note that this length is the same regardless to the size of memory. + pub fn len(&self) -> usize { + assert_eq!(self.memory32.len(), self.memory64.len()); + self.memory32.len() + } +} + +// Note that this is intentionally duplicated here to keep the size to 1 byte +// irregardless to changes in the core wasm type system since this will only +// ever use integers/floats for the forseeable future. +#[derive(PartialEq, Eq, Copy, Clone)] +#[allow(missing_docs)] +pub enum FlatType { + I32, + I64, + F32, + F64, +} + +#[derive(Default)] +struct FlatTypesCache { + records: PrimaryMap, + variants: PrimaryMap, + tuples: PrimaryMap, + enums: PrimaryMap, + flags: PrimaryMap, + unions: PrimaryMap, + options: PrimaryMap, + expecteds: PrimaryMap, +} + +struct FlatTypesStorage { + // This could be represented as `Vec` but on 64-bit architectures + // that's 24 bytes. Otherwise `FlatType` is 1 byte large and + // `MAX_FLAT_TYPES` is 16, so it should ideally be more space-efficient to + // use a flat array instead of a heap-based vector. + memory32: [FlatType; MAX_FLAT_TYPES], + memory64: [FlatType; MAX_FLAT_TYPES], + + // Tracks the number of flat types pushed into this storage. If this is + // `MAX_FLAT_TYPES + 1` then this storage represents an un-reprsentable + // type in flat types. + len: u8, +} + +impl FlatTypesStorage { + fn new() -> FlatTypesStorage { + FlatTypesStorage { + memory32: [FlatType::I32; MAX_FLAT_TYPES], + memory64: [FlatType::I32; MAX_FLAT_TYPES], + len: 0, + } + } + + fn as_flat_types(&self) -> Option> { + let len = usize::from(self.len); + if len > MAX_FLAT_TYPES { + assert_eq!(len, MAX_FLAT_TYPES + 1); + None + } else { + Some(FlatTypes { + memory32: &self.memory32[..len], + memory64: &self.memory64[..len], + }) + } + } + + /// Pushes a new flat type into this list using `t32` for 32-bit memories + /// and `t64` for 64-bit memories. + /// + /// Returns whether the type was actually pushed or whether this list of + /// flat types just exceeded the maximum meaning that it is now + /// unrepresentable with a flat list of types. + fn push(&mut self, t32: FlatType, t64: FlatType) -> bool { + let len = usize::from(self.len); + if len < MAX_FLAT_TYPES { + self.memory32[len] = t32; + self.memory64[len] = t64; + self.len += 1; + true + } else { + // If this was the first one to go over then flag the length as + // being incompatible with a flat representation. + if len == MAX_FLAT_TYPES { + self.len += 1; + } + false + } + } + + /// Builds up all flat types internally using the specified representation + /// for all of the component fields of the record. + fn build_record<'a>(&mut self, types: impl Iterator>>) { + for ty in types { + let types = match ty { + Some(types) => types, + None => { + self.len = u8::try_from(MAX_FLAT_TYPES + 1).unwrap(); + return; + } + }; + for (t32, t64) in types.memory32.iter().zip(types.memory64) { + if !self.push(*t32, *t64) { + return; + } + } + } + } + + /// Builds up the flat types used to represent a `variant` which notably + /// handles "join"ing types together so each case is representable as a + /// single flat list of types. + fn build_variant<'a, I>(&mut self, cases: I) + where + I: IntoIterator>>, + { + let cases = cases.into_iter(); + self.push(FlatType::I32, FlatType::I32); + + for ty in cases { + let types = match ty { + Some(types) => types, + // If this case isn't representable with a flat list of types + // then this variant also isn't representable. + None => { + self.len = u8::try_from(MAX_FLAT_TYPES + 1).unwrap(); + return; + } + }; + // If the case used all of the flat types then the discriminant + // added for this variant means that this variant is no longer + // representable. + if types.memory32.len() >= MAX_FLAT_TYPES { + self.len = u8::try_from(MAX_FLAT_TYPES + 1).unwrap(); + return; + } + let dst = self.memory32.iter_mut().zip(&mut self.memory64).skip(1); + for (i, ((t32, t64), (dst32, dst64))) in types + .memory32 + .iter() + .zip(types.memory64) + .zip(dst) + .enumerate() + { + if i + 1 < usize::from(self.len) { + // If this index hs already been set by some previous case + // then the types are joined together. + dst32.join(*t32); + dst64.join(*t64); + } else { + // Otherwise if this is the first time that the + // representation has gotten this large then the destination + // is simply whatever the type is. The length is also + // increased here to indicate this. + self.len += 1; + *dst32 = *t32; + *dst64 = *t64; + } + } + } + } + + fn records(&mut self, types: &ComponentTypesBuilder, ty: &TypeRecord) { + self.build_record(ty.fields.iter().map(|f| types.flat_types(&f.ty))); + } + + fn tuples(&mut self, types: &ComponentTypesBuilder, ty: &TypeTuple) { + self.build_record(ty.types.iter().map(|t| types.flat_types(t))); + } + + fn enums(&mut self, _types: &ComponentTypesBuilder, _ty: &TypeEnum) { + self.push(FlatType::I32, FlatType::I32); + } + + fn flags(&mut self, _types: &ComponentTypesBuilder, ty: &TypeFlags) { + match FlagsSize::from_count(ty.names.len()) { + FlagsSize::Size0 => {} + FlagsSize::Size1 | FlagsSize::Size2 => { + self.push(FlatType::I32, FlatType::I32); + } + FlagsSize::Size4Plus(n) => { + for _ in 0..n { + self.push(FlatType::I32, FlatType::I32); + } + } + } + } + + fn variants(&mut self, types: &ComponentTypesBuilder, ty: &TypeVariant) { + self.build_variant(ty.cases.iter().map(|c| types.flat_types(&c.ty))) + } + + fn unions(&mut self, types: &ComponentTypesBuilder, ty: &TypeUnion) { + self.build_variant(ty.types.iter().map(|t| types.flat_types(t))) + } + + fn expecteds(&mut self, types: &ComponentTypesBuilder, ty: &TypeExpected) { + self.build_variant([types.flat_types(&ty.ok), types.flat_types(&ty.err)]); + } + + fn options(&mut self, types: &ComponentTypesBuilder, ty: &TypeOption) { + self.build_variant([Some(FlatTypes::EMPTY), types.flat_types(&ty.ty)]); + } +} + +impl FlatType { + fn join(&mut self, other: FlatType) { + if *self == other { + return; + } + *self = match (*self, other) { + (FlatType::I32, FlatType::F32) | (FlatType::F32, FlatType::I32) => FlatType::I32, + _ => FlatType::I64, + }; + } +} diff --git a/crates/environ/src/fact.rs b/crates/environ/src/fact.rs index 895db3ea5c..2827b21249 100644 --- a/crates/environ/src/fact.rs +++ b/crates/environ/src/fact.rs @@ -20,8 +20,8 @@ use crate::component::dfg::CoreDef; use crate::component::{ - Adapter, AdapterOptions as AdapterOptionsDfg, ComponentTypes, InterfaceType, StringEncoding, - TypeFuncIndex, + Adapter, AdapterOptions as AdapterOptionsDfg, ComponentTypesBuilder, InterfaceType, + StringEncoding, TypeFuncIndex, }; use crate::fact::transcode::Transcoder; use crate::{EntityRef, FuncIndex, GlobalIndex, MemoryIndex, PrimaryMap}; @@ -41,7 +41,7 @@ pub struct Module<'a> { /// Whether or not debug code is inserted into the adapters themselves. debug: bool, /// Type information from the creator of this `Module` - types: &'a ComponentTypes, + types: &'a ComponentTypesBuilder, /// Core wasm type section that's incrementally built core_types: core_types::CoreTypes, @@ -125,7 +125,7 @@ enum Context { impl<'a> Module<'a> { /// Creates an empty module. - pub fn new(types: &'a ComponentTypes, debug: bool) -> Module<'a> { + pub fn new(types: &'a ComponentTypesBuilder, debug: bool) -> Module<'a> { Module { debug, types, diff --git a/crates/environ/src/fact/signature.rs b/crates/environ/src/fact/signature.rs index 7440831b14..c93836c1fa 100644 --- a/crates/environ/src/fact/signature.rs +++ b/crates/environ/src/fact/signature.rs @@ -1,9 +1,10 @@ //! Size, align, and flattening information about component model types. -use crate::component::{ComponentTypes, InterfaceType, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS}; +use crate::component::{ + ComponentTypesBuilder, FlatType, InterfaceType, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, +}; use crate::fact::{AdapterOptions, Context, Options}; use wasm_encoder::ValType; -use wasmtime_component_util::FlagsSize; /// Metadata about a core wasm signature which is created for a component model /// signature. @@ -23,7 +24,7 @@ pub struct Signature { pub results_indirect: bool, } -impl ComponentTypes { +impl ComponentTypesBuilder { /// Calculates the core wasm function signature for the component function /// type specified within `Context`. /// @@ -34,31 +35,39 @@ impl ComponentTypes { let ty = &self[options.ty]; let ptr_ty = options.options.ptr(); - let mut params = self.flatten_types(&options.options, ty.params.iter().map(|(_, ty)| *ty)); let mut params_indirect = false; - if params.len() > MAX_FLAT_PARAMS { - params = vec![ptr_ty]; - params_indirect = true; - } + let mut params = match self.flatten_types( + &options.options, + MAX_FLAT_PARAMS, + ty.params.iter().map(|(_, ty)| *ty), + ) { + Some(list) => list, + None => { + params_indirect = true; + vec![ptr_ty] + } + }; - let mut results = self.flatten_types(&options.options, [ty.result]); let mut results_indirect = false; - if results.len() > MAX_FLAT_RESULTS { - results_indirect = true; - match context { - // For a lifted function too-many-results gets translated to a - // returned pointer where results are read from. The callee - // allocates space here. - Context::Lift => results = vec![ptr_ty], - // For a lowered function too-many-results becomes a return - // pointer which is passed as the last argument. The caller - // allocates space here. - Context::Lower => { - results.truncate(0); - params.push(ptr_ty); + let results = match self.flatten_types(&options.options, MAX_FLAT_RESULTS, [ty.result]) { + Some(list) => list, + None => { + results_indirect = true; + match context { + // For a lifted function too-many-results gets translated to a + // returned pointer where results are read from. The callee + // allocates space here. + Context::Lift => vec![ptr_ty], + // For a lowered function too-many-results becomes a return + // pointer which is passed as the last argument. The caller + // allocates space here. + Context::Lower => { + params.push(ptr_ty); + Vec::new() + } } } - } + }; Signature { params, results, @@ -72,115 +81,31 @@ impl ComponentTypes { pub(super) fn flatten_types( &self, opts: &Options, + max: usize, tys: impl IntoIterator, - ) -> Vec { - let mut result = Vec::new(); + ) -> Option> { + let mut dst = Vec::new(); for ty in tys { - self.push_flat(opts, &ty, &mut result); - } - result - } - - fn push_flat(&self, opts: &Options, ty: &InterfaceType, dst: &mut Vec) { - match ty { - InterfaceType::Unit => {} - - InterfaceType::Bool - | InterfaceType::S8 - | InterfaceType::U8 - | InterfaceType::S16 - | InterfaceType::U16 - | InterfaceType::S32 - | InterfaceType::U32 - | InterfaceType::Char => dst.push(ValType::I32), - - InterfaceType::S64 | InterfaceType::U64 => dst.push(ValType::I64), - - InterfaceType::Float32 => dst.push(ValType::F32), - InterfaceType::Float64 => dst.push(ValType::F64), - - InterfaceType::String | InterfaceType::List(_) => { - dst.push(opts.ptr()); - dst.push(opts.ptr()); - } - InterfaceType::Record(r) => { - for field in self[*r].fields.iter() { - self.push_flat(opts, &field.ty, dst); + let flat = self.flat_types(&ty)?; + let types = if opts.memory64 { + flat.memory64 + } else { + flat.memory32 + }; + for ty in types { + let ty = match ty { + FlatType::I32 => ValType::I32, + FlatType::I64 => ValType::I64, + FlatType::F32 => ValType::F32, + FlatType::F64 => ValType::F64, + }; + if dst.len() == max { + return None; } - } - InterfaceType::Tuple(t) => { - for ty in self[*t].types.iter() { - self.push_flat(opts, ty, dst); - } - } - InterfaceType::Flags(f) => { - let flags = &self[*f]; - match FlagsSize::from_count(flags.names.len()) { - FlagsSize::Size0 => {} - FlagsSize::Size1 | FlagsSize::Size2 => dst.push(ValType::I32), - FlagsSize::Size4Plus(n) => { - dst.extend((0..n).map(|_| ValType::I32)); - } - } - } - InterfaceType::Enum(_) => dst.push(ValType::I32), - InterfaceType::Option(t) => { - dst.push(ValType::I32); - self.push_flat(opts, &self[*t].ty, dst); - } - InterfaceType::Variant(t) => { - dst.push(ValType::I32); - let pos = dst.len(); - let mut tmp = Vec::new(); - for case in self[*t].cases.iter() { - self.push_flat_variant(opts, &case.ty, pos, &mut tmp, dst); - } - } - InterfaceType::Union(t) => { - dst.push(ValType::I32); - let pos = dst.len(); - let mut tmp = Vec::new(); - for ty in self[*t].types.iter() { - self.push_flat_variant(opts, ty, pos, &mut tmp, dst); - } - } - InterfaceType::Expected(t) => { - dst.push(ValType::I32); - let e = &self[*t]; - let pos = dst.len(); - let mut tmp = Vec::new(); - self.push_flat_variant(opts, &e.ok, pos, &mut tmp, dst); - self.push_flat_variant(opts, &e.err, pos, &mut tmp, dst); - } - } - } - - fn push_flat_variant( - &self, - opts: &Options, - ty: &InterfaceType, - pos: usize, - tmp: &mut Vec, - dst: &mut Vec, - ) { - tmp.truncate(0); - self.push_flat(opts, ty, tmp); - for (i, a) in tmp.iter().enumerate() { - match dst.get_mut(pos + i) { - Some(b) => join(*a, b), - None => dst.push(*a), - } - } - - fn join(a: ValType, b: &mut ValType) { - if a == *b { - return; - } - match (a, *b) { - (ValType::I32, ValType::F32) | (ValType::F32, ValType::I32) => *b = ValType::I32, - _ => *b = ValType::I64, + dst.push(ty); } } + Some(dst) } pub(super) fn align(&self, opts: &Options, ty: &InterfaceType) -> u32 { diff --git a/crates/environ/src/fact/trampoline.rs b/crates/environ/src/fact/trampoline.rs index faaf32d3ea..bedc8d5781 100644 --- a/crates/environ/src/fact/trampoline.rs +++ b/crates/environ/src/fact/trampoline.rs @@ -16,7 +16,7 @@ //! can be somewhat arbitrary, an intentional decision. use crate::component::{ - CanonicalAbiInfo, ComponentTypes, InterfaceType, StringEncoding, TypeEnumIndex, + CanonicalAbiInfo, ComponentTypesBuilder, InterfaceType, StringEncoding, TypeEnumIndex, TypeExpectedIndex, TypeFlagsIndex, TypeInterfaceIndex, TypeOptionIndex, TypeRecordIndex, TypeTupleIndex, TypeUnionIndex, TypeVariantIndex, VariantInfo, FLAG_MAY_ENTER, FLAG_MAY_LEAVE, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, @@ -36,7 +36,7 @@ const MAX_STRING_BYTE_LENGTH: u32 = 1 << 31; const UTF16_TAG: u32 = 1 << 31; struct Compiler<'a, 'b> { - types: &'a ComponentTypes, + types: &'a ComponentTypesBuilder, module: &'b mut Module<'a>, result: FunctionId, @@ -279,14 +279,16 @@ impl Compiler<'_, '_> { // TODO: handle subtyping assert_eq!(src_tys.len(), dst_tys.len()); - let src_flat = self - .types - .flatten_types(lower_opts, src_tys.iter().copied()); - let dst_flat = self.types.flatten_types(lift_opts, dst_tys.iter().copied()); + let src_flat = + self.types + .flatten_types(lower_opts, MAX_FLAT_PARAMS, src_tys.iter().copied()); + let dst_flat = + self.types + .flatten_types(lift_opts, MAX_FLAT_PARAMS, dst_tys.iter().copied()); - let src = if src_flat.len() <= MAX_FLAT_PARAMS { + let src = if let Some(flat) = &src_flat { Source::Stack(Stack { - locals: ¶m_locals[..src_flat.len()], + locals: ¶m_locals[..flat.len()], opts: lower_opts, }) } else { @@ -303,8 +305,8 @@ impl Compiler<'_, '_> { Source::Memory(self.memory_operand(lower_opts, TempLocal::new(addr, ty), align)) }; - let dst = if dst_flat.len() <= MAX_FLAT_PARAMS { - Destination::Stack(&dst_flat, lift_opts) + let dst = if let Some(flat) = &dst_flat { + Destination::Stack(flat, lift_opts) } else { // If there are too many parameters then space is allocated in the // destination module for the parameters via its `realloc` function. @@ -348,10 +350,14 @@ impl Compiler<'_, '_> { let lift_opts = &adapter.lift.options; let lower_opts = &adapter.lower.options; - let src_flat = self.types.flatten_types(lift_opts, [src_ty]); - let dst_flat = self.types.flatten_types(lower_opts, [dst_ty]); + let src_flat = self + .types + .flatten_types(lift_opts, MAX_FLAT_RESULTS, [src_ty]); + let dst_flat = self + .types + .flatten_types(lower_opts, MAX_FLAT_RESULTS, [dst_ty]); - let src = if src_flat.len() <= MAX_FLAT_RESULTS { + let src = if src_flat.is_some() { Source::Stack(Stack { locals: result_locals, opts: lift_opts, @@ -368,8 +374,8 @@ impl Compiler<'_, '_> { Source::Memory(self.memory_operand(lift_opts, TempLocal::new(addr, ty), align)) }; - let dst = if dst_flat.len() <= MAX_FLAT_RESULTS { - Destination::Stack(&dst_flat, lower_opts) + let dst = if let Some(flat) = &dst_flat { + Destination::Stack(flat, lower_opts) } else { // This is slightly different than `translate_params` where the // return pointer was provided by the caller of this function @@ -1937,6 +1943,7 @@ impl Compiler<'_, '_> { FlagsSize::Size4Plus(n) => { let srcs = src.record_field_srcs(self.types, (0..n).map(|_| InterfaceType::U32)); let dsts = dst.record_field_dsts(self.types, (0..n).map(|_| InterfaceType::U32)); + let n = usize::from(n); for (i, (src, dst)) in srcs.zip(dsts).enumerate() { let mask = if i == n - 1 && (cnt % 32 != 0) { (1 << (cnt % 32)) - 1 @@ -2775,7 +2782,7 @@ impl<'a> Source<'a> { /// offset for each memory-based type. fn record_field_srcs<'b>( &'b self, - types: &'b ComponentTypes, + types: &'b ComponentTypesBuilder, fields: impl IntoIterator + 'b, ) -> impl Iterator> + 'b where @@ -2788,7 +2795,7 @@ impl<'a> Source<'a> { Source::Memory(mem) } Source::Stack(stack) => { - let cnt = types.flatten_types(stack.opts, [ty]).len() as u32; + let cnt = types.flat_types(&ty).unwrap().len() as u32; offset += cnt; Source::Stack(stack.slice((offset - cnt) as usize..offset as usize)) } @@ -2798,13 +2805,13 @@ impl<'a> Source<'a> { /// Returns the corresponding discriminant source and payload source f fn payload_src( &self, - types: &ComponentTypes, + types: &ComponentTypesBuilder, info: &VariantInfo, case: &InterfaceType, ) -> Source<'a> { match self { Source::Stack(s) => { - let flat_len = types.flatten_types(s.opts, [*case]).len(); + let flat_len = types.flat_types(case).unwrap().len(); Source::Stack(s.slice(1..s.locals.len()).slice(0..flat_len)) } Source::Memory(mem) => { @@ -2830,7 +2837,7 @@ impl<'a> Destination<'a> { /// Same as `Source::record_field_srcs` but for destinations. fn record_field_dsts<'b>( &'b self, - types: &'b ComponentTypes, + types: &'b ComponentTypesBuilder, fields: impl IntoIterator + 'b, ) -> impl Iterator + 'b where @@ -2843,7 +2850,7 @@ impl<'a> Destination<'a> { Destination::Memory(mem) } Destination::Stack(s, opts) => { - let cnt = types.flatten_types(opts, [ty]).len() as u32; + let cnt = types.flat_types(&ty).unwrap().len() as u32; offset += cnt; Destination::Stack(&s[(offset - cnt) as usize..offset as usize], opts) } @@ -2853,13 +2860,13 @@ impl<'a> Destination<'a> { /// Returns the corresponding discriminant source and payload source f fn payload_dst( &self, - types: &ComponentTypes, + types: &ComponentTypesBuilder, info: &VariantInfo, case: &InterfaceType, ) -> Destination { match self { Destination::Stack(s, opts) => { - let flat_len = types.flatten_types(opts, [*case]).len(); + let flat_len = types.flat_types(case).unwrap().len(); Destination::Stack(&s[1..][..flat_len], opts) } Destination::Memory(mem) => { @@ -2883,7 +2890,7 @@ impl<'a> Destination<'a> { fn next_field_offset<'a>( offset: &mut u32, - types: &ComponentTypes, + types: &ComponentTypesBuilder, field: &InterfaceType, mem: &Memory<'a>, ) -> Memory<'a> { @@ -2930,7 +2937,7 @@ struct VariantCase<'a> { dst_ty: &'a InterfaceType, } -fn variant_info(types: &ComponentTypes, cases: I) -> VariantInfo +fn variant_info(types: &ComponentTypesBuilder, cases: I) -> VariantInfo where I: IntoIterator, I::IntoIter: ExactSizeIterator, diff --git a/crates/misc/component-fuzz-util/src/lib.rs b/crates/misc/component-fuzz-util/src/lib.rs index e011b4888d..8387d29d9c 100644 --- a/crates/misc/component-fuzz-util/src/lib.rs +++ b/crates/misc/component-fuzz-util/src/lib.rs @@ -166,7 +166,7 @@ fn u32_count_from_flag_count(count: usize) -> usize { match FlagsSize::from_count(count) { FlagsSize::Size0 => 0, FlagsSize::Size1 | FlagsSize::Size2 => 1, - FlagsSize::Size4Plus(n) => n, + FlagsSize::Size4Plus(n) => n.into(), } } @@ -270,7 +270,7 @@ impl Type { alignment: 2, }, FlagsSize::Size4Plus(n) => SizeAndAlignment { - size: n * 4, + size: usize::from(n) * 4, alignment: 4, }, }, diff --git a/crates/wasmtime/src/component/func.rs b/crates/wasmtime/src/component/func.rs index 28bc5c1601..396d6d83c1 100644 --- a/crates/wasmtime/src/component/func.rs +++ b/crates/wasmtime/src/component/func.rs @@ -302,15 +302,15 @@ impl Func { result = Type::from(&ty.result, &data.types); } - let param_count = params.iter().map(|ty| ty.flatten_count()).sum::(); - let result_count = result.flatten_count(); + let param_abi = CanonicalAbiInfo::record(params.iter().map(|t| t.canonical_abi())); + let result_count = result.canonical_abi().flat_count(MAX_FLAT_RESULTS); self.call_raw( store, args, |store, options, args, dst: &mut MaybeUninit<[ValRaw; MAX_FLAT_PARAMS]>| { - if param_count > MAX_FLAT_PARAMS { - self.store_args(store, &options, ¶ms, args, dst) + if param_abi.flat_count(MAX_FLAT_PARAMS).is_none() { + self.store_args(store, &options, ¶m_abi, ¶ms, args, dst) } else { dst.write([ValRaw::u64(0); MAX_FLAT_PARAMS]); @@ -324,7 +324,7 @@ impl Func { } }, |store, options, src: &[ValRaw; MAX_FLAT_RESULTS]| { - if result_count > MAX_FLAT_RESULTS { + if result_count.is_none() { Self::load_result(&Memory::new(store, &options), &result, &mut src.iter()) } else { Val::lift(&result, store, &options, &mut src.iter()) @@ -554,12 +554,11 @@ impl Func { &self, store: &mut StoreContextMut<'_, T>, options: &Options, + abi: &CanonicalAbiInfo, params: &[Type], args: &[Val], dst: &mut MaybeUninit<[ValRaw; MAX_FLAT_PARAMS]>, ) -> Result<()> { - let abi = CanonicalAbiInfo::record(params.iter().map(|t| t.canonical_abi())); - let mut memory = MemoryMut::new(store.as_context_mut(), options); let size = usize::try_from(abi.size32).unwrap(); let ptr = memory.realloc(0, 0, abi.align32, size)?; diff --git a/crates/wasmtime/src/component/func/host.rs b/crates/wasmtime/src/component/func/host.rs index 6f3c450cb1..1d6b5d5c20 100644 --- a/crates/wasmtime/src/component/func/host.rs +++ b/crates/wasmtime/src/component/func/host.rs @@ -400,12 +400,19 @@ where bail!("cannot leave component instance"); } - let param_count = params.iter().map(|ty| ty.flatten_count()).sum::(); - let args; let ret_index; - if param_count <= MAX_FLAT_PARAMS { + let param_abi = CanonicalAbiInfo::record(params.iter().map(|t| t.canonical_abi())); + let param_count = param_abi.flat_count.and_then(|i| { + let i = usize::from(i); + if i > MAX_FLAT_PARAMS { + None + } else { + Some(i) + } + }); + if let Some(param_count) = param_count { let iter = &mut storage.iter(); args = params .iter() @@ -413,8 +420,6 @@ where .collect::>>()?; ret_index = param_count; } else { - let param_abi = CanonicalAbiInfo::record(params.iter().map(|t| t.canonical_abi())); - let memory = Memory::new(cx.0, &options); let mut offset = validate_inbounds_dynamic(¶m_abi, memory.as_slice(), &storage[0])?; args = params @@ -436,8 +441,8 @@ where flags.set_may_leave(false); result.check(&ret)?; - let result_count = result.flatten_count(); - if result_count <= MAX_FLAT_RESULTS { + let result_count = result.canonical_abi().flat_count(MAX_FLAT_RESULTS); + if result_count.is_some() { let dst = mem::transmute::<&mut [ValRaw], &mut [MaybeUninit]>(storage); ret.lower(&mut cx, &options, &mut dst.iter_mut())?; } else { diff --git a/crates/wasmtime/src/component/types.rs b/crates/wasmtime/src/component/types.rs index cc065d72fc..1b30135ae5 100644 --- a/crates/wasmtime/src/component/types.rs +++ b/crates/wasmtime/src/component/types.rs @@ -78,6 +78,10 @@ impl Record { ty: Type::from(&field.ty, &self.0.types), }) } + + pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo { + &self.0.types[self.0.index].abi + } } /// A `tuple` interface type @@ -97,6 +101,10 @@ impl Tuple { .iter() .map(|ty| Type::from(ty, &self.0.types)) } + + pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo { + &self.0.types[self.0.index].abi + } } /// A case declaration belonging to a `variant` @@ -128,6 +136,10 @@ impl Variant { pub(crate) fn variant_info(&self) -> &VariantInfo { &self.0.types[self.0.index].info } + + pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo { + &self.0.types[self.0.index].abi + } } /// An `enum` interface type @@ -151,6 +163,10 @@ impl Enum { pub(crate) fn variant_info(&self) -> &VariantInfo { &self.0.types[self.0.index].info } + + pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo { + &self.0.types[self.0.index].abi + } } /// A `union` interface type @@ -174,6 +190,10 @@ impl Union { pub(crate) fn variant_info(&self) -> &VariantInfo { &self.0.types[self.0.index].info } + + pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo { + &self.0.types[self.0.index].abi + } } /// An `option` interface type @@ -194,6 +214,10 @@ impl Option { pub(crate) fn variant_info(&self) -> &VariantInfo { &self.0.types[self.0.index].info } + + pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo { + &self.0.types[self.0.index].abi + } } /// An `expected` interface type @@ -219,6 +243,10 @@ impl Expected { pub(crate) fn variant_info(&self) -> &VariantInfo { &self.0.types[self.0.index].info } + + pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo { + &self.0.types[self.0.index].abi + } } /// A `flags` interface type @@ -238,6 +266,10 @@ impl Flags { .iter() .map(|name| name.deref()) } + + pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo { + &self.0.types[self.0.index].abi + } } /// Represents a component model interface type @@ -483,60 +515,6 @@ impl Type { } } - /// Return the number of stack slots needed to store values of this type in lowered form. - pub(crate) fn flatten_count(&self) -> usize { - match self { - Type::Unit => 0, - - Type::Bool - | Type::S8 - | Type::U8 - | Type::S16 - | Type::U16 - | Type::S32 - | Type::U32 - | Type::S64 - | Type::U64 - | Type::Float32 - | Type::Float64 - | Type::Char - | Type::Enum(_) => 1, - - Type::String | Type::List(_) => 2, - - Type::Record(handle) => handle.fields().map(|field| field.ty.flatten_count()).sum(), - - Type::Tuple(handle) => handle.types().map(|ty| ty.flatten_count()).sum(), - - Type::Variant(handle) => { - 1 + handle - .cases() - .map(|case| case.ty.flatten_count()) - .max() - .unwrap_or(0) - } - - Type::Union(handle) => { - 1 + handle - .types() - .map(|ty| ty.flatten_count()) - .max() - .unwrap_or(0) - } - - Type::Option(handle) => 1 + handle.ty().flatten_count(), - - Type::Expected(handle) => { - 1 + handle - .ok() - .flatten_count() - .max(handle.err().flatten_count()) - } - - Type::Flags(handle) => values::u32_count_for_flag_count(handle.names().len()), - } - } - fn desc(&self) -> &'static str { match self { Type::Unit => "unit", @@ -574,14 +552,14 @@ impl Type { Type::S32 | Type::U32 | Type::Char | Type::Float32 => &CanonicalAbiInfo::SCALAR4, Type::S64 | Type::U64 | Type::Float64 => &CanonicalAbiInfo::SCALAR8, Type::String | Type::List(_) => &CanonicalAbiInfo::POINTER_PAIR, - Type::Record(handle) => &handle.0.types[handle.0.index].abi, - Type::Tuple(handle) => &handle.0.types[handle.0.index].abi, - Type::Variant(handle) => &handle.0.types[handle.0.index].abi, - Type::Enum(handle) => &handle.0.types[handle.0.index].abi, - Type::Union(handle) => &handle.0.types[handle.0.index].abi, - Type::Option(handle) => &handle.0.types[handle.0.index].abi, - Type::Expected(handle) => &handle.0.types[handle.0.index].abi, - Type::Flags(handle) => &handle.0.types[handle.0.index].abi, + Type::Record(handle) => handle.canonical_abi(), + Type::Tuple(handle) => handle.canonical_abi(), + Type::Variant(handle) => handle.canonical_abi(), + Type::Enum(handle) => handle.canonical_abi(), + Type::Union(handle) => handle.canonical_abi(), + Type::Option(handle) => handle.canonical_abi(), + Type::Expected(handle) => handle.canonical_abi(), + Type::Flags(handle) => handle.canonical_abi(), } } } diff --git a/crates/wasmtime/src/component/values.rs b/crates/wasmtime/src/component/values.rs index 3c85f857ce..fefffcbfe2 100644 --- a/crates/wasmtime/src/component/values.rs +++ b/crates/wasmtime/src/component/values.rs @@ -440,7 +440,8 @@ impl Flags { .map(|(index, name)| (name, index)) .collect::>(); - let mut values = vec![0_u32; u32_count_for_flag_count(ty.names().len())]; + let count = usize::from(ty.canonical_abi().flat_count.unwrap()); + let mut values = vec![0_u32; count]; for name in names { let index = map @@ -611,7 +612,7 @@ impl Val { }), Type::Variant(handle) => { let (discriminant, value) = lift_variant( - ty.flatten_count(), + handle.canonical_abi().flat_count(usize::MAX).unwrap(), handle.cases().map(|case| case.ty), store, options, @@ -626,7 +627,7 @@ impl Val { } Type::Enum(handle) => { let (discriminant, _) = lift_variant( - ty.flatten_count(), + handle.canonical_abi().flat_count(usize::MAX).unwrap(), handle.names().map(|_| Type::Unit), store, options, @@ -639,8 +640,13 @@ impl Val { }) } Type::Union(handle) => { - let (discriminant, value) = - lift_variant(ty.flatten_count(), handle.types(), store, options, src)?; + let (discriminant, value) = lift_variant( + handle.canonical_abi().flat_count(usize::MAX).unwrap(), + handle.types(), + store, + options, + src, + )?; Val::Union(Union { ty: handle.clone(), @@ -650,7 +656,7 @@ impl Val { } Type::Option(handle) => { let (discriminant, value) = lift_variant( - ty.flatten_count(), + handle.canonical_abi().flat_count(usize::MAX).unwrap(), [Type::Unit, handle.ty()].into_iter(), store, options, @@ -665,7 +671,7 @@ impl Val { } Type::Expected(handle) => { let (discriminant, value) = lift_variant( - ty.flatten_count(), + handle.canonical_abi().flat_count(usize::MAX).unwrap(), [handle.ok(), handle.err()].into_iter(), store, options, @@ -680,8 +686,9 @@ impl Val { } Type::Flags(handle) => { let count = u32::try_from(handle.names().len()).unwrap(); + let u32_count = handle.canonical_abi().flat_count(usize::MAX).unwrap(); let value = iter::repeat_with(|| u32::lift(store, options, next(src))) - .take(u32_count_for_flag_count(count.try_into()?)) + .take(u32_count) .collect::>()?; Val::Flags(Flags { @@ -797,7 +804,7 @@ impl Val { FlagsSize::Size1 => iter::once(u8::load(mem, bytes)? as u32).collect(), FlagsSize::Size2 => iter::once(u16::load(mem, bytes)? as u32).collect(), FlagsSize::Size4Plus(n) => (0..n) - .map(|index| u32::load(mem, &bytes[index * 4..][..4])) + .map(|index| u32::load(mem, &bytes[usize::from(index) * 4..][..4])) .collect::>()?, }, }), @@ -868,7 +875,9 @@ impl Val { }) => { next_mut(dst).write(ValRaw::u32(*discriminant)); value.lower(store, options, dst)?; - for _ in (1 + value.ty().flatten_count())..self.ty().flatten_count() { + let value_flat = value.ty().canonical_abi().flat_count(usize::MAX).unwrap(); + let variant_flat = self.ty().canonical_abi().flat_count(usize::MAX).unwrap(); + for _ in (1 + value_flat)..variant_flat { next_mut(dst).write(ValRaw::u32(0)); } } @@ -1070,7 +1079,8 @@ fn lift_variant<'a>( .nth(discriminant as usize) .ok_or_else(|| anyhow!("discriminant {} out of range [0..{})", discriminant, len))?; let value = Val::lift(&ty, store, options, src)?; - for _ in (1 + ty.flatten_count())..flatten_count { + let value_flat = ty.canonical_abi().flat_count(usize::MAX).unwrap(); + for _ in (1 + value_flat)..flatten_count { next(src); } Ok((discriminant, value)) @@ -1098,17 +1108,6 @@ fn lower_list( Ok((ptr, items.len())) } -/// Calculate the size of a u32 array needed to represent the specified number of bit flags. -/// -/// Note that this will always return at least 1, even if the `count` parameter is zero. -pub(crate) fn u32_count_for_flag_count(count: usize) -> usize { - match FlagsSize::from_count(count) { - FlagsSize::Size0 => 0, - FlagsSize::Size1 | FlagsSize::Size2 => 1, - FlagsSize::Size4Plus(n) => n, - } -} - fn next<'a>(src: &mut std::slice::Iter<'a, ValRaw>) -> &'a ValRaw { src.next().unwrap() }