From bd70dbebbd7f2f9f0fb4b61880c0fc7a5ae8fff4 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Tue, 9 Aug 2022 14:52:20 -0500 Subject: [PATCH] Deduplicate some size/align calculations (#4658) This commit is an effort to reduce the amount of complexity around managing the size/alignment calculations of types in the canonical ABI. Previously the logic for the size/alignment of a type was spread out across a number of locations. While each individual calculation is not really the most complicated thing in the world having the duplication in so many places was constantly worrying me. I've opted in this commit to centralize all of this within the runtime at least, and now there's only one "duplicate" of this information in the fuzzing infrastructure which is to some degree less important to deduplicate. This commit introduces a new `CanonicalAbiInfo` type to house all abi size/align information for both memory32 and memory64. This new type is then used pervasively throughout fused adapter compilation, dynamic `Val` management, and typed functions. This type was also able to reduce the complexity of the macro-generated code meaning that even `wasmtime-component-macro` is performing less math than it was before. One other major feature of this commit is that this ABI information is now saved within a `ComponentTypes` structure. This avoids recursive querying of size/align information frequently and instead effectively caching it. This was a worry I had for the fused adapter compiler which frequently sought out size/align information and would recursively descend each type tree each time. The `fact-valid-module` fuzzer is now nearly 10x faster in terms of iterations/s which I suspect is due to this caching. --- Cargo.lock | 1 + crates/component-macro/src/lib.rs | 87 +--- crates/component-util/src/lib.rs | 23 +- crates/environ/fuzz/Cargo.toml | 1 + .../fuzz/fuzz_targets/fact-valid-module.rs | 220 +++----- crates/environ/src/component/types.rs | 493 ++++++++++++++++-- crates/environ/src/fact/signature.rs | 105 +--- crates/environ/src/fact/trampoline.rs | 147 +++--- crates/misc/component-test-util/src/lib.rs | 5 +- crates/wasmtime/src/component/func.rs | 25 +- crates/wasmtime/src/component/func/host.rs | 33 +- crates/wasmtime/src/component/func/typed.rs | 166 +++--- crates/wasmtime/src/component/mod.rs | 6 +- crates/wasmtime/src/component/types.rs | 165 ++---- crates/wasmtime/src/component/values.rs | 131 ++--- 15 files changed, 845 insertions(+), 763 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b952743cea..3d6bfc50d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3559,6 +3559,7 @@ dependencies = [ "wasmparser", "wasmprinter", "wasmtime-environ", + "wat", ] [[package]] diff --git a/crates/component-macro/src/lib.rs b/crates/component-macro/src/lib.rs index 06f11c6021..7809caf7dd 100644 --- a/crates/component-macro/src/lib.rs +++ b/crates/component-macro/src/lib.rs @@ -298,7 +298,7 @@ fn expand_record_for_component_type( let mut lower_generic_params = TokenStream::new(); let mut lower_generic_args = TokenStream::new(); let mut lower_field_declarations = TokenStream::new(); - let mut sizes = TokenStream::new(); + let mut abi_list = TokenStream::new(); let mut unique_types = HashSet::new(); for (index, syn::Field { ident, ty, .. }) in fields.iter().enumerate() { @@ -309,24 +309,13 @@ fn expand_record_for_component_type( lower_field_declarations.extend(quote!(#ident: #generic,)); - sizes.extend(quote!( - size = #internal::align_to(size, <#ty as wasmtime::component::ComponentType>::ALIGN32); - size += <#ty as wasmtime::component::ComponentType>::SIZE32; + abi_list.extend(quote!( + <#ty as wasmtime::component::ComponentType>::ABI, )); unique_types.insert(ty); } - let alignments = unique_types - .into_iter() - .map(|ty| { - let align = quote!(<#ty as wasmtime::component::ComponentType>::ALIGN32); - quote!(if #align > align { - align = #align; - }) - }) - .collect::(); - let generics = add_trait_bounds(generics, parse_quote!(wasmtime::component::ComponentType)); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let lower = format_ident!("Lower{}", name); @@ -358,17 +347,8 @@ fn expand_record_for_component_type( unsafe impl #impl_generics wasmtime::component::ComponentType for #name #ty_generics #where_clause { type Lower = #lower <#lower_generic_args>; - const SIZE32: usize = { - let mut size = 0; - #sizes - #internal::align_to(size, Self::ALIGN32) - }; - - const ALIGN32: u32 = { - let mut align = 1; - #alignments - align - }; + const ABI: #internal::CanonicalAbiInfo = + #internal::CanonicalAbiInfo::record_static(&[#abi_list]); #[inline] fn typecheck( @@ -429,7 +409,7 @@ impl Expander for LiftExpander { loads.extend(quote!(#ident: <#ty as wasmtime::component::Lift>::load( memory, &bytes - [#internal::next_field::<#ty>(&mut offset)..] + [<#ty as wasmtime::component::ComponentType>::ABI.next_field32_size(&mut offset)..] [..<#ty as wasmtime::component::ComponentType>::SIZE32] )?,)); } @@ -514,8 +494,6 @@ impl Expander for LiftExpander { DiscriminantSize::Size4 => quote!(u32::from_le_bytes(bytes[0..4].try_into()?)), }; - let payload_offset = usize::from(discriminant_size); - let expanded = quote! { unsafe impl #impl_generics wasmtime::component::Lift for #name #ty_generics #where_clause { #[inline] @@ -535,7 +513,8 @@ impl Expander for LiftExpander { let align = ::ALIGN32; debug_assert!((bytes.as_ptr() as usize) % (align as usize) == 0); let discrim = #from_bytes; - let payload = &bytes[#internal::align_to(#payload_offset, align)..]; + let payload_offset = ::PAYLOAD_OFFSET32; + let payload = &bytes[payload_offset..]; Ok(match discrim { #loads discrim => #internal::anyhow::bail!("unexpected discriminant: {}", discrim), @@ -575,7 +554,9 @@ impl Expander for LowerExpander { )?;)); stores.extend(quote!(wasmtime::component::Lower::store( - &self.#ident, memory, #internal::next_field::<#ty>(&mut offset) + &self.#ident, + memory, + <#ty as wasmtime::component::ComponentType>::ABI.next_field32_size(&mut offset), )?;)); } @@ -640,10 +621,7 @@ impl Expander for LowerExpander { lower = quote!(value.lower(store, options, #internal::map_maybe_uninit!(dst.payload.#ident))); store = quote!(value.store( memory, - offset + #internal::align_to( - #discriminant_size, - ::ALIGN32 - ) + offset + ::PAYLOAD_OFFSET32, )); } else { pattern = quote!(Self::#ident); @@ -749,7 +727,7 @@ impl Expander for ComponentTypeExpander { &self, name: &syn::Ident, generics: &syn::Generics, - discriminant_size: DiscriminantSize, + _discriminant_size: DiscriminantSize, cases: &[VariantCase], style: VariantStyle, ) -> Result { @@ -760,7 +738,7 @@ impl Expander for ComponentTypeExpander { let mut lower_payload_generic_args = TokenStream::new(); let mut lower_payload_case_declarations = TokenStream::new(); let mut lower_generic_args = TokenStream::new(); - let mut sizes = TokenStream::new(); + let mut abi_list = TokenStream::new(); let mut unique_types = HashSet::new(); for (index, VariantCase { attrs, ident, ty }) in cases.iter().enumerate() { @@ -776,12 +754,7 @@ impl Expander for ComponentTypeExpander { let name = rename.unwrap_or_else(|| Literal::string(&ident.to_string())); if let Some(ty) = ty { - sizes.extend({ - let size = quote!(<#ty as wasmtime::component::ComponentType>::SIZE32); - quote!(if #size > size { - size = #size; - }) - }); + abi_list.extend(quote!(<#ty as wasmtime::component::ComponentType>::ABI,)); case_names_and_checks.extend(match style { VariantStyle::Variant => { @@ -808,6 +781,7 @@ impl Expander for ComponentTypeExpander { unique_types.insert(ty); } else { + abi_list.extend(quote!(<() as wasmtime::component::ComponentType>::ABI,)); case_names_and_checks.extend(match style { VariantStyle::Variant => { quote!((#name, <() as wasmtime::component::ComponentType>::typecheck),) @@ -824,16 +798,6 @@ impl Expander for ComponentTypeExpander { lower_payload_case_declarations.extend(quote!(_dummy: ())); } - let alignments = unique_types - .into_iter() - .map(|ty| { - let align = quote!(<#ty as wasmtime::component::ComponentType>::ALIGN32); - quote!(if #align > align { - align = #align; - }) - }) - .collect::(); - let typecheck = match style { VariantStyle::Variant => quote!(typecheck_variant), VariantStyle::Union => quote!(typecheck_union), @@ -844,7 +808,6 @@ impl Expander for ComponentTypeExpander { let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let lower = format_ident!("Lower{}", name); let lower_payload = format_ident!("LowerPayload{}", name); - let discriminant_size = u32::from(discriminant_size); // You may wonder why we make the types of all the fields of the #lower struct and #lower_payload union // generic. This is to work around a [normalization bug in @@ -882,20 +845,12 @@ impl Expander for ComponentTypeExpander { #internal::#typecheck(ty, types, &[#case_names_and_checks]) } - const SIZE32: usize = { - let mut size = 0; - #sizes - #internal::align_to( - #internal::align_to(#discriminant_size as usize, Self::ALIGN32) + size, - Self::ALIGN32 - ) - }; + const ABI: #internal::CanonicalAbiInfo = + #internal::CanonicalAbiInfo::variant_static(&[#abi_list]); + } - const ALIGN32: u32 = { - let mut align = #discriminant_size; - #alignments - align - }; + unsafe impl #impl_generics #internal::ComponentVariant for #name #ty_generics #where_clause { + const CASES: &'static [#internal::CanonicalAbiInfo] = &[#abi_list]; } }; diff --git a/crates/component-util/src/lib.rs b/crates/component-util/src/lib.rs index 827b168cb6..3aa77df304 100644 --- a/crates/component-util/src/lib.rs +++ b/crates/component-util/src/lib.rs @@ -1,5 +1,5 @@ /// Represents the possible sizes in bytes of the discriminant of a variant type in the component model -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum DiscriminantSize { /// 8-bit discriminant Size1, @@ -11,7 +11,7 @@ pub enum DiscriminantSize { impl DiscriminantSize { /// Calculate the size of discriminant needed to represent a variant with the specified number of cases. - pub fn from_count(count: usize) -> Option { + pub const fn from_count(count: usize) -> Option { if count <= 0xFF { Some(Self::Size1) } else if count <= 0xFFFF { @@ -22,16 +22,21 @@ impl DiscriminantSize { None } } + + /// Returns the size, in bytes, of this discriminant + pub const fn byte_size(&self) -> u32 { + match self { + DiscriminantSize::Size1 => 1, + DiscriminantSize::Size2 => 2, + DiscriminantSize::Size4 => 4, + } + } } impl From for u32 { /// Size of the discriminant as a `u32` fn from(size: DiscriminantSize) -> u32 { - match size { - DiscriminantSize::Size1 => 1, - DiscriminantSize::Size2 => 2, - DiscriminantSize::Size4 => 4, - } + size.byte_size() } } @@ -60,7 +65,7 @@ pub enum FlagsSize { impl FlagsSize { /// Calculate the size needed to represent a value with the specified number of flags. - pub fn from_count(count: usize) -> FlagsSize { + pub const fn from_count(count: usize) -> FlagsSize { if count == 0 { FlagsSize::Size0 } else if count <= 8 { @@ -74,7 +79,7 @@ impl FlagsSize { } /// Divide `n` by `d`, rounding up in the case of a non-zero remainder. -fn ceiling_divide(n: usize, d: usize) -> usize { +const fn ceiling_divide(n: usize, d: usize) -> usize { (n + d - 1) / d } diff --git a/crates/environ/fuzz/Cargo.toml b/crates/environ/fuzz/Cargo.toml index 4cd5b12151..1793bcc848 100644 --- a/crates/environ/fuzz/Cargo.toml +++ b/crates/environ/fuzz/Cargo.toml @@ -14,6 +14,7 @@ env_logger = "0.9.0" libfuzzer-sys = "0.4" wasmparser = "0.88.0" wasmprinter = "0.2.37" +wat = "1.0" wasmtime-environ = { path = ".." } component-fuzz-util = { path = "../../misc/component-fuzz-util", optional = true } diff --git a/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs b/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs index 5c7114a4a3..2fd52d9004 100644 --- a/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs +++ b/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs @@ -10,9 +10,9 @@ #![no_main] use arbitrary::Arbitrary; -use component_fuzz_util::Type as ValType; +use component_fuzz_util::TestCase; use libfuzzer_sys::fuzz_target; -use wasmparser::{Validator, WasmFeatures}; +use wasmparser::{Parser, Payload, Validator, WasmFeatures}; use wasmtime_environ::component::*; use wasmtime_environ::fact::Module; @@ -24,25 +24,10 @@ struct GenAdapterModule { #[derive(Arbitrary, Debug)] struct GenAdapter { - ty: FuncType, post_return: bool, lift_memory64: bool, lower_memory64: bool, - lift_encoding: GenStringEncoding, - lower_encoding: GenStringEncoding, -} - -#[derive(Arbitrary, Debug)] -struct FuncType { - params: Vec, - result: ValType, -} - -#[derive(Copy, Clone, Arbitrary, Debug)] -enum GenStringEncoding { - Utf8, - Utf16, - CompactUtf16, + test: TestCase, } fuzz_target!(|module: GenAdapterModule| { @@ -90,47 +75,74 @@ fn target(module: GenAdapterModule) { let mut adapters = Vec::new(); for adapter in module.adapters.iter() { - let mut params = Vec::new(); - for param in adapter.ty.params.iter() { - params.push((None, intern(&mut types, param))); + let wat_decls = adapter.test.declarations(); + let wat = format!( + "(component + {types} + (type (func {params} {result})) + )", + types = wat_decls.types, + params = wat_decls.params, + result = wat_decls.result, + ); + let wasm = wat::parse_str(&wat).unwrap(); + + let mut validator = Validator::new_with_features(WasmFeatures { + component_model: true, + ..Default::default() + }); + + types.push_type_scope(); + for payload in Parser::new(0).parse_all(&wasm) { + let payload = payload.unwrap(); + validator.payload(&payload).unwrap(); + let section = match payload { + Payload::ComponentTypeSection(s) => s, + _ => continue, + }; + for ty in section { + let ty = types.intern_component_type(&ty.unwrap()).unwrap(); + types.push_component_typedef(ty); + let ty = match ty { + TypeDef::ComponentFunc(ty) => ty, + _ => continue, + }; + adapters.push(Adapter { + lift_ty: ty, + lower_ty: ty, + lower_options: AdapterOptions { + instance: RuntimeComponentInstanceIndex::from_u32(0), + string_encoding: convert_encoding(adapter.test.encoding1), + memory64: adapter.lower_memory64, + // Pessimistically assume that memory/realloc are going to be + // required for this trampoline and provide it. Avoids doing + // calculations to figure out whether they're necessary and + // simplifies the fuzzer here without reducing coverage within FACT + // itself. + memory: Some(dummy_memory(adapter.lower_memory64)), + realloc: Some(dummy_def()), + // Lowering never allows `post-return` + post_return: None, + }, + lift_options: AdapterOptions { + instance: RuntimeComponentInstanceIndex::from_u32(1), + string_encoding: convert_encoding(adapter.test.encoding2), + memory64: adapter.lift_memory64, + memory: Some(dummy_memory(adapter.lift_memory64)), + realloc: Some(dummy_def()), + post_return: if adapter.post_return { + Some(dummy_def()) + } else { + None + }, + }, + func: dummy_def(), + }); + } } - let result = intern(&mut types, &adapter.ty.result); - let signature = types.add_func_type(TypeFunc { - params: params.into(), - result, - }); - adapters.push(Adapter { - lift_ty: signature, - lower_ty: signature, - lower_options: AdapterOptions { - instance: RuntimeComponentInstanceIndex::from_u32(0), - string_encoding: adapter.lower_encoding.into(), - memory64: adapter.lower_memory64, - // Pessimistically assume that memory/realloc are going to be - // required for this trampoline and provide it. Avoids doing - // calculations to figure out whether they're necessary and - // simplifies the fuzzer here without reducing coverage within FACT - // itself. - memory: Some(dummy_memory(adapter.lower_memory64)), - realloc: Some(dummy_def()), - // Lowering never allows `post-return` - post_return: None, - }, - lift_options: AdapterOptions { - instance: RuntimeComponentInstanceIndex::from_u32(1), - string_encoding: adapter.lift_encoding.into(), - memory64: adapter.lift_memory64, - memory: Some(dummy_memory(adapter.lift_memory64)), - realloc: Some(dummy_def()), - post_return: if adapter.post_return { - Some(dummy_def()) - } else { - None - }, - }, - func: dummy_def(), - }); + types.pop_type_scope(); } + let types = types.finish(); let mut fact_module = Module::new(&types, module.debug); for (i, adapter) in adapters.iter().enumerate() { @@ -161,94 +173,10 @@ fn target(module: GenAdapterModule) { panic!() } -fn intern(types: &mut ComponentTypesBuilder, ty: &ValType) -> InterfaceType { - match ty { - ValType::Unit => InterfaceType::Unit, - ValType::Bool => InterfaceType::Bool, - ValType::U8 => InterfaceType::U8, - ValType::S8 => InterfaceType::S8, - ValType::U16 => InterfaceType::U16, - ValType::S16 => InterfaceType::S16, - ValType::U32 => InterfaceType::U32, - ValType::S32 => InterfaceType::S32, - ValType::U64 => InterfaceType::U64, - ValType::S64 => InterfaceType::S64, - ValType::Float32 => InterfaceType::Float32, - ValType::Float64 => InterfaceType::Float64, - ValType::Char => InterfaceType::Char, - ValType::String => InterfaceType::String, - ValType::List(ty) => { - let ty = intern(types, ty); - InterfaceType::List(types.add_interface_type(ty)) - } - ValType::Record(tys) => { - let ty = TypeRecord { - fields: tys - .iter() - .enumerate() - .map(|(i, ty)| RecordField { - name: format!("f{i}"), - ty: intern(types, ty), - }) - .collect(), - }; - InterfaceType::Record(types.add_record_type(ty)) - } - ValType::Flags(size) => { - let ty = TypeFlags { - names: (0..size.as_usize()).map(|i| format!("f{i}")).collect(), - }; - InterfaceType::Flags(types.add_flags_type(ty)) - } - ValType::Tuple(tys) => { - let ty = TypeTuple { - types: tys.iter().map(|ty| intern(types, ty)).collect(), - }; - InterfaceType::Tuple(types.add_tuple_type(ty)) - } - ValType::Variant(cases) => { - let ty = TypeVariant { - cases: cases - .iter() - .enumerate() - .map(|(i, ty)| VariantCase { - name: format!("c{i}"), - ty: intern(types, ty), - }) - .collect(), - }; - InterfaceType::Variant(types.add_variant_type(ty)) - } - ValType::Union(tys) => { - let ty = TypeUnion { - types: tys.iter().map(|ty| intern(types, ty)).collect(), - }; - InterfaceType::Union(types.add_union_type(ty)) - } - ValType::Enum(size) => { - let ty = TypeEnum { - names: (0..size.as_usize()).map(|i| format!("c{i}")).collect(), - }; - InterfaceType::Enum(types.add_enum_type(ty)) - } - ValType::Option(ty) => { - let ty = intern(types, ty); - InterfaceType::Option(types.add_interface_type(ty)) - } - ValType::Expected { ok, err } => { - let ok = intern(types, ok); - let err = intern(types, err); - InterfaceType::Expected(types.add_expected_type(TypeExpected { ok, err })) - } - } -} - -impl From for StringEncoding { - fn from(gen: GenStringEncoding) -> StringEncoding { - match gen { - GenStringEncoding::Utf8 => StringEncoding::Utf8, - GenStringEncoding::Utf16 => StringEncoding::Utf16, - GenStringEncoding::CompactUtf16 => StringEncoding::CompactUtf16, - } +fn convert_encoding(encoding: component_fuzz_util::StringEncoding) -> StringEncoding { + match encoding { + component_fuzz_util::StringEncoding::Utf8 => StringEncoding::Utf8, + component_fuzz_util::StringEncoding::Utf16 => StringEncoding::Utf16, + component_fuzz_util::StringEncoding::Latin1OrUtf16 => StringEncoding::CompactUtf16, } } diff --git a/crates/environ/src/component/types.rs b/crates/environ/src/component/types.rs index da13b7de5c..5c70f3691f 100644 --- a/crates/environ/src/component/types.rs +++ b/crates/environ/src/component/types.rs @@ -11,6 +11,7 @@ use std::ops::Index; use wasmparser::{ ComponentAlias, ComponentOuterAliasKind, ComponentTypeDeclaration, InstanceTypeDeclaration, }; +use wasmtime_component_util::{DiscriminantSize, FlagsSize}; macro_rules! indices { ($( @@ -89,6 +90,9 @@ indices! { pub struct TypeEnumIndex(u32); /// Index pointing to a union type in the component model. pub struct TypeUnionIndex(u32); + /// Index pointing to an option type in the component model (aka a + /// `Option`) + pub struct TypeOptionIndex(u32); /// Index pointing to an expected type in the component model (aka a /// `Result`) pub struct TypeExpectedIndex(u32); @@ -209,6 +213,7 @@ pub struct ComponentTypes { enums: PrimaryMap, flags: PrimaryMap, unions: PrimaryMap, + options: PrimaryMap, expecteds: PrimaryMap, module_types: ModuleTypes, @@ -219,6 +224,39 @@ impl ComponentTypes { pub fn module_types(&self) -> &ModuleTypes { &self.module_types } + + /// Returns the canonical ABI information about the specified type. + pub fn canonical_abi(&self, ty: &InterfaceType) -> &CanonicalAbiInfo { + match ty { + InterfaceType::Unit => &CanonicalAbiInfo::ZERO, + + InterfaceType::U8 | InterfaceType::S8 | InterfaceType::Bool => { + &CanonicalAbiInfo::SCALAR1 + } + + InterfaceType::U16 | InterfaceType::S16 => &CanonicalAbiInfo::SCALAR2, + + InterfaceType::U32 + | InterfaceType::S32 + | InterfaceType::Float32 + | InterfaceType::Char => &CanonicalAbiInfo::SCALAR4, + + InterfaceType::U64 | InterfaceType::S64 | InterfaceType::Float64 => { + &CanonicalAbiInfo::SCALAR8 + } + + InterfaceType::String | InterfaceType::List(_) => &CanonicalAbiInfo::POINTER_PAIR, + + InterfaceType::Record(i) => &self[*i].abi, + InterfaceType::Variant(i) => &self[*i].abi, + InterfaceType::Tuple(i) => &self[*i].abi, + InterfaceType::Flags(i) => &self[*i].abi, + InterfaceType::Enum(i) => &self[*i].abi, + InterfaceType::Union(i) => &self[*i].abi, + InterfaceType::Option(i) => &self[*i].abi, + InterfaceType::Expected(i) => &self[*i].abi, + } + } } macro_rules! impl_index { @@ -244,6 +282,7 @@ impl_index! { impl Index for ComponentTypes { TypeEnum => enums } impl Index for ComponentTypes { TypeFlags => flags } impl Index for ComponentTypes { TypeUnion => unions } + impl Index for ComponentTypes { TypeOption => options } impl Index for ComponentTypes { TypeExpected => expecteds } } @@ -274,6 +313,7 @@ pub struct ComponentTypesBuilder { enums: HashMap, flags: HashMap, unions: HashMap, + options: HashMap, expecteds: HashMap, component_types: ComponentTypes, @@ -599,8 +639,7 @@ impl ComponentTypesBuilder { wasmparser::ComponentDefinedType::Enum(e) => InterfaceType::Enum(self.enum_type(e)), wasmparser::ComponentDefinedType::Union(e) => InterfaceType::Union(self.union_type(e)), wasmparser::ComponentDefinedType::Option(e) => { - let ty = self.valtype(e); - InterfaceType::Option(self.add_interface_type(ty)) + InterfaceType::Option(self.option_type(e)) } wasmparser::ComponentDefinedType::Expected { ok, error } => { InterfaceType::Expected(self.expected_type(ok, error)) @@ -623,62 +662,90 @@ impl ComponentTypesBuilder { } fn record_type(&mut self, record: &[(&str, wasmparser::ComponentValType)]) -> TypeRecordIndex { - let record = TypeRecord { - fields: record + let fields = record + .iter() + .map(|(name, ty)| RecordField { + name: name.to_string(), + ty: self.valtype(ty), + }) + .collect::>(); + let abi = CanonicalAbiInfo::record( + fields .iter() - .map(|(name, ty)| RecordField { - name: name.to_string(), - ty: self.valtype(ty), - }) - .collect(), - }; - self.add_record_type(record) + .map(|field| self.component_types.canonical_abi(&field.ty)), + ); + self.add_record_type(TypeRecord { fields, abi }) } fn variant_type(&mut self, cases: &[wasmparser::VariantCase<'_>]) -> TypeVariantIndex { - let variant = TypeVariant { - cases: cases + let cases = cases + .iter() + .map(|case| { + // FIXME: need to implement `refines`, not sure what that + // is at this time. + assert!(case.refines.is_none()); + VariantCase { + name: case.name.to_string(), + ty: self.valtype(&case.ty), + } + }) + .collect::>(); + let (info, abi) = VariantInfo::new( + cases .iter() - .map(|case| { - // FIXME: need to implement `refines`, not sure what that - // is at this time. - assert!(case.refines.is_none()); - VariantCase { - name: case.name.to_string(), - ty: self.valtype(&case.ty), - } - }) - .collect(), - }; - self.add_variant_type(variant) + .map(|c| self.component_types.canonical_abi(&c.ty)), + ); + self.add_variant_type(TypeVariant { cases, abi, info }) } fn tuple_type(&mut self, types: &[wasmparser::ComponentValType]) -> TypeTupleIndex { - let tuple = TypeTuple { - types: types.iter().map(|ty| self.valtype(ty)).collect(), - }; - self.add_tuple_type(tuple) + let types = types + .iter() + .map(|ty| self.valtype(ty)) + .collect::>(); + let abi = CanonicalAbiInfo::record( + types + .iter() + .map(|ty| self.component_types.canonical_abi(ty)), + ); + self.add_tuple_type(TypeTuple { types, abi }) } fn flags_type(&mut self, flags: &[&str]) -> TypeFlagsIndex { let flags = TypeFlags { names: flags.iter().map(|s| s.to_string()).collect(), + abi: CanonicalAbiInfo::flags(flags.len()), }; self.add_flags_type(flags) } fn enum_type(&mut self, variants: &[&str]) -> TypeEnumIndex { - let e = TypeEnum { - names: variants.iter().map(|s| s.to_string()).collect(), - }; - self.add_enum_type(e) + let names = variants.iter().map(|s| s.to_string()).collect::>(); + let (info, abi) = VariantInfo::new( + names + .iter() + .map(|_| self.component_types.canonical_abi(&InterfaceType::Unit)), + ); + self.add_enum_type(TypeEnum { names, abi, info }) } fn union_type(&mut self, types: &[wasmparser::ComponentValType]) -> TypeUnionIndex { - let union = TypeUnion { - types: types.iter().map(|ty| self.valtype(ty)).collect(), - }; - self.add_union_type(union) + let types = types + .iter() + .map(|ty| self.valtype(ty)) + .collect::>(); + let (info, abi) = + VariantInfo::new(types.iter().map(|t| self.component_types.canonical_abi(t))); + self.add_union_type(TypeUnion { types, abi, info }) + } + + fn option_type(&mut self, ty: &wasmparser::ComponentValType) -> TypeOptionIndex { + let ty = self.valtype(ty); + let (info, abi) = VariantInfo::new([ + self.component_types.canonical_abi(&InterfaceType::Unit), + self.component_types.canonical_abi(&ty), + ]); + self.add_option_type(TypeOption { ty, abi, info }) } fn expected_type( @@ -686,11 +753,13 @@ impl ComponentTypesBuilder { ok: &wasmparser::ComponentValType, err: &wasmparser::ComponentValType, ) -> TypeExpectedIndex { - let expected = TypeExpected { - ok: self.valtype(ok), - err: self.valtype(err), - }; - self.add_expected_type(expected) + let ok = self.valtype(ok); + let err = self.valtype(err); + let (info, abi) = VariantInfo::new([ + self.component_types.canonical_abi(&ok), + self.component_types.canonical_abi(&err), + ]); + self.add_expected_type(TypeExpected { ok, err, abi, info }) } /// Interns a new function type within this type information. @@ -728,6 +797,11 @@ impl ComponentTypesBuilder { intern(&mut self.enums, &mut self.component_types.enums, ty) } + /// Interns a new option type within this type information. + pub fn add_option_type(&mut self, ty: TypeOption) -> TypeOptionIndex { + intern(&mut self.options, &mut self.component_types.options, ty) + } + /// Interns a new expected type within this type information. pub fn add_expected_type(&mut self, ty: TypeExpected) -> TypeExpectedIndex { intern(&mut self.expecteds, &mut self.component_types.expecteds, ty) @@ -875,7 +949,7 @@ pub enum InterfaceType { Flags(TypeFlagsIndex), Enum(TypeEnumIndex), Union(TypeUnionIndex), - Option(TypeInterfaceIndex), + Option(TypeOptionIndex), Expected(TypeExpectedIndex), } @@ -900,6 +974,306 @@ impl From<&wasmparser::PrimitiveValType> for InterfaceType { } } +/// Bye information about a type in the canonical ABI, with metadata for both +/// memory32 and memory64-based types. +#[derive(Serialize, Deserialize, Clone, Hash, Eq, PartialEq, Debug)] +pub struct CanonicalAbiInfo { + /// The byte-size of this type in a 32-bit memory. + pub size32: u32, + /// The byte-alignment of this type in a 32-bit memory. + pub align32: u32, + /// The byte-size of this type in a 64-bit memory. + pub size64: u32, + /// The byte-alignment of this type in a 64-bit memory. + pub align64: u32, +} + +impl Default for CanonicalAbiInfo { + fn default() -> CanonicalAbiInfo { + CanonicalAbiInfo { + size32: 0, + align32: 1, + size64: 0, + align64: 1, + } + } +} + +const fn align_to(a: u32, b: u32) -> u32 { + assert!(b.is_power_of_two()); + (a + (b - 1)) & !(b - 1) +} + +const fn max(a: u32, b: u32) -> u32 { + if a > b { + a + } else { + b + } +} + +impl CanonicalAbiInfo { + /// ABI information for zero-sized types. + pub const ZERO: CanonicalAbiInfo = CanonicalAbiInfo { + size32: 0, + align32: 1, + size64: 0, + align64: 1, + }; + + /// ABI information for one-byte scalars. + pub const SCALAR1: CanonicalAbiInfo = CanonicalAbiInfo::scalar(1); + /// ABI information for two-byte scalars. + pub const SCALAR2: CanonicalAbiInfo = CanonicalAbiInfo::scalar(2); + /// ABI information for four-byte scalars. + pub const SCALAR4: CanonicalAbiInfo = CanonicalAbiInfo::scalar(4); + /// ABI information for eight-byte scalars. + pub const SCALAR8: CanonicalAbiInfo = CanonicalAbiInfo::scalar(8); + + const fn scalar(size: u32) -> CanonicalAbiInfo { + CanonicalAbiInfo { + size32: size, + align32: size, + size64: size, + align64: size, + } + } + + /// ABI information for lists/strings which are "pointer pairs" + pub const POINTER_PAIR: CanonicalAbiInfo = CanonicalAbiInfo { + size32: 8, + align32: 4, + size64: 16, + align64: 8, + }; + + /// Returns the abi for a record represented by the specified fields. + pub fn record<'a>(fields: impl Iterator) -> CanonicalAbiInfo { + // NB: this is basically a duplicate copy of + // `CanonicalAbiInfo::record_static` and the two should be kept in sync. + + let mut ret = CanonicalAbiInfo::default(); + for field in fields { + ret.size32 = align_to(ret.size32, field.align32) + field.size32; + ret.align32 = ret.align32.max(field.align32); + ret.size64 = align_to(ret.size64, field.align64) + field.size64; + ret.align64 = ret.align64.max(field.align64); + } + ret.size32 = align_to(ret.size32, ret.align32); + ret.size64 = align_to(ret.size64, ret.align64); + return ret; + } + + /// Same as `CanonicalAbiInfo::record` but in a `const`-friendly context. + pub const fn record_static(fields: &[CanonicalAbiInfo]) -> CanonicalAbiInfo { + // NB: this is basically a duplicate copy of `CanonicalAbiInfo::record` + // and the two should be kept in sync. + + let mut ret = CanonicalAbiInfo::ZERO; + let mut i = 0; + while i < fields.len() { + let field = &fields[i]; + ret.size32 = align_to(ret.size32, field.align32) + field.size32; + ret.align32 = max(ret.align32, field.align32); + ret.size64 = align_to(ret.size64, field.align64) + field.size64; + ret.align64 = max(ret.align64, field.align64); + i += 1; + } + ret.size32 = align_to(ret.size32, ret.align32); + ret.size64 = align_to(ret.size64, ret.align64); + return ret; + } + + /// Returns the delta from the current value of `offset` to align properly + /// and read the next record field of type `abi` for 32-bit memories. + pub fn next_field32(&self, offset: &mut u32) -> u32 { + *offset = align_to(*offset, self.align32) + self.size32; + *offset - self.size32 + } + + /// Same as `next_field32`, but bumps a usize pointer + pub fn next_field32_size(&self, offset: &mut usize) -> usize { + let cur = u32::try_from(*offset).unwrap(); + let cur = align_to(cur, self.align32) + self.size32; + *offset = usize::try_from(cur).unwrap(); + usize::try_from(cur - self.size32).unwrap() + } + + /// Returns the delta from the current value of `offset` to align properly + /// and read the next record field of type `abi` for 64-bit memories. + pub fn next_field64(&self, offset: &mut u32) -> u32 { + *offset = align_to(*offset, self.align64) + self.size64; + *offset - self.size64 + } + + /// Same as `next_field64`, but bumps a usize pointer + pub fn next_field64_size(&self, offset: &mut usize) -> usize { + let cur = u32::try_from(*offset).unwrap(); + let cur = align_to(cur, self.align64) + self.size64; + *offset = usize::try_from(cur).unwrap(); + usize::try_from(cur - self.size64).unwrap() + } + + /// Returns ABI information for a structure which contains `count` flags. + pub const fn flags(count: usize) -> CanonicalAbiInfo { + let (size, align) = match FlagsSize::from_count(count) { + FlagsSize::Size0 => (0, 1), + FlagsSize::Size1 => (1, 1), + FlagsSize::Size2 => (2, 2), + FlagsSize::Size4Plus(n) => ((n as u32) * 4, 4), + }; + CanonicalAbiInfo { + size32: size, + align32: align, + size64: size, + align64: align, + } + } + + fn variant<'a, I>(cases: I) -> CanonicalAbiInfo + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + // NB: this is basically a duplicate definition of + // `CanonicalAbiInfo::variant_static`, these should be kept in sync. + + let cases = cases.into_iter(); + let discrim_size = u32::from(DiscriminantSize::from_count(cases.len()).unwrap()); + let mut max_size32 = 0; + let mut max_align32 = discrim_size; + let mut max_size64 = 0; + let mut max_align64 = discrim_size; + for case in cases { + max_size32 = max_size32.max(case.size32); + max_align32 = max_align32.max(case.align32); + max_size64 = max_size64.max(case.size64); + max_align64 = max_align64.max(case.align64); + } + CanonicalAbiInfo { + size32: align_to( + align_to(discrim_size, max_align32) + max_size32, + max_align32, + ), + align32: max_align32, + size64: align_to( + align_to(discrim_size, max_align64) + max_size64, + max_align64, + ), + align64: max_align64, + } + } + + /// Same as `CanonicalAbiInfo::variant` but `const`-safe + pub const fn variant_static(cases: &[CanonicalAbiInfo]) -> CanonicalAbiInfo { + // NB: this is basically a duplicate definition of + // `CanonicalAbiInfo::variant`, these should be kept in sync. + + let discrim_size = match DiscriminantSize::from_count(cases.len()) { + Some(size) => size.byte_size(), + None => unreachable!(), + }; + let mut max_size32 = 0; + let mut max_align32 = discrim_size; + let mut max_size64 = 0; + let mut max_align64 = discrim_size; + let mut i = 0; + while i < cases.len() { + let case = &cases[i]; + max_size32 = max(max_size32, case.size32); + max_align32 = max(max_align32, case.align32); + max_size64 = max(max_size64, case.size64); + max_align64 = max(max_align64, case.align64); + i += 1; + } + CanonicalAbiInfo { + size32: align_to( + align_to(discrim_size, max_align32) + max_size32, + max_align32, + ), + align32: max_align32, + size64: align_to( + align_to(discrim_size, max_align64) + max_size64, + max_align64, + ), + align64: max_align64, + } + } +} + +/// ABI information about the representation of a variant. +#[derive(Serialize, Deserialize, Clone, Hash, Eq, PartialEq, Debug)] +pub struct VariantInfo { + /// The size of the discriminant used. + #[serde(with = "serde_discrim_size")] + pub size: DiscriminantSize, + /// The offset of the payload from the start of the variant in 32-bit + /// memories. + pub payload_offset32: u32, + /// The offset of the payload from the start of the variant in 64-bit + /// memories. + pub payload_offset64: u32, +} + +impl VariantInfo { + /// Returns the abi information for a variant represented by the specified + /// cases. + pub fn new<'a, I>(cases: I) -> (VariantInfo, CanonicalAbiInfo) + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let cases = cases.into_iter(); + let size = DiscriminantSize::from_count(cases.len()).unwrap(); + let abi = CanonicalAbiInfo::variant(cases); + ( + VariantInfo { + size, + payload_offset32: align_to(u32::from(size), abi.align32), + payload_offset64: align_to(u32::from(size), abi.align64), + }, + abi, + ) + } + /// TODO + pub const fn new_static(cases: &[CanonicalAbiInfo]) -> VariantInfo { + let size = match DiscriminantSize::from_count(cases.len()) { + Some(size) => size, + None => unreachable!(), + }; + let abi = CanonicalAbiInfo::variant_static(cases); + VariantInfo { + size, + payload_offset32: align_to(size.byte_size(), abi.align32), + payload_offset64: align_to(size.byte_size(), abi.align64), + } + } +} + +mod serde_discrim_size { + use super::DiscriminantSize; + use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer}; + + pub fn serialize(disc: &DiscriminantSize, ser: S) -> Result + where + S: Serializer, + { + u32::from(*disc).serialize(ser) + } + + pub fn deserialize<'de, D>(deser: D) -> Result + where + D: Deserializer<'de>, + { + match u32::deserialize(deser)? { + 1 => Ok(DiscriminantSize::Size1), + 2 => Ok(DiscriminantSize::Size2), + 4 => Ok(DiscriminantSize::Size4), + _ => Err(D::Error::custom("invalid discriminant size")), + } + } +} + /// Shape of a "record" type in interface types. /// /// This is equivalent to a `struct` in Rust. @@ -907,6 +1281,8 @@ impl From<&wasmparser::PrimitiveValType> for InterfaceType { pub struct TypeRecord { /// The fields that are contained within this struct type. pub fields: Box<[RecordField]>, + /// Byte information about this type in the canonical ABI. + pub abi: CanonicalAbiInfo, } /// One field within a record. @@ -927,6 +1303,10 @@ pub struct RecordField { pub struct TypeVariant { /// The list of cases that this variant can take. pub cases: Box<[VariantCase]>, + /// Byte information about this type in the canonical ABI. + pub abi: CanonicalAbiInfo, + /// Byte information about this variant type. + pub info: VariantInfo, } /// One case of a `variant` type which contains the name of the variant as well @@ -947,6 +1327,8 @@ pub struct VariantCase { pub struct TypeTuple { /// The types that are contained within this tuple. pub types: Box<[InterfaceType]>, + /// Byte information about this type in the canonical ABI. + pub abi: CanonicalAbiInfo, } /// Shape of a "flags" type in interface types. @@ -957,6 +1339,8 @@ pub struct TypeTuple { pub struct TypeFlags { /// The names of all flags, all of which are unique. pub names: Box<[String]>, + /// Byte information about this type in the canonical ABI. + pub abi: CanonicalAbiInfo, } /// Shape of an "enum" type in interface types, not to be confused with a Rust @@ -968,6 +1352,10 @@ pub struct TypeFlags { pub struct TypeEnum { /// The names of this enum, all of which are unique. pub names: Box<[String]>, + /// Byte information about this type in the canonical ABI. + pub abi: CanonicalAbiInfo, + /// Byte information about this variant type. + pub info: VariantInfo, } /// Shape of a "union" type in interface types. @@ -979,6 +1367,21 @@ pub struct TypeEnum { pub struct TypeUnion { /// The list of types this is a union over. pub types: Box<[InterfaceType]>, + /// Byte information about this type in the canonical ABI. + pub abi: CanonicalAbiInfo, + /// Byte information about this variant type. + pub info: VariantInfo, +} + +/// Shape of an "option" interface type. +#[derive(Serialize, Deserialize, Clone, Hash, Eq, PartialEq, Debug)] +pub struct TypeOption { + /// The `T` in `Result` + pub ty: InterfaceType, + /// Byte information about this type in the canonical ABI. + pub abi: CanonicalAbiInfo, + /// Byte information about this variant type. + pub info: VariantInfo, } /// Shape of an "expected" interface type. @@ -988,4 +1391,8 @@ pub struct TypeExpected { pub ok: InterfaceType, /// The `E` in `Result` pub err: InterfaceType, + /// Byte information about this type in the canonical ABI. + pub abi: CanonicalAbiInfo, + /// Byte information about this variant type. + pub info: VariantInfo, } diff --git a/crates/environ/src/fact/signature.rs b/crates/environ/src/fact/signature.rs index f6b8b3fb73..7440831b14 100644 --- a/crates/environ/src/fact/signature.rs +++ b/crates/environ/src/fact/signature.rs @@ -3,7 +3,7 @@ use crate::component::{ComponentTypes, InterfaceType, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS}; use crate::fact::{AdapterOptions, Context, Options}; use wasm_encoder::ValType; -use wasmtime_component_util::{DiscriminantSize, FlagsSize}; +use wasmtime_component_util::FlagsSize; /// Metadata about a core wasm signature which is created for a component model /// signature. @@ -23,11 +23,6 @@ pub struct Signature { pub results_indirect: bool, } -pub(crate) fn align_to(n: usize, align: usize) -> usize { - assert!(align.is_power_of_two()); - (n + (align - 1)) & !(align - 1) -} - impl ComponentTypes { /// Calculates the core wasm function signature for the component function /// type specified within `Context`. @@ -120,15 +115,18 @@ impl ComponentTypes { } InterfaceType::Flags(f) => { let flags = &self[*f]; - let nflags = align_to(flags.names.len(), 32) / 32; - for _ in 0..nflags { - dst.push(ValType::I32); + match FlagsSize::from_count(flags.names.len()) { + FlagsSize::Size0 => {} + FlagsSize::Size1 | FlagsSize::Size2 => dst.push(ValType::I32), + FlagsSize::Size4Plus(n) => { + dst.extend((0..n).map(|_| ValType::I32)); + } } } InterfaceType::Enum(_) => dst.push(ValType::I32), InterfaceType::Option(t) => { dst.push(ValType::I32); - self.push_flat(opts, &self[*t], dst); + self.push_flat(opts, &self[*t].ty, dst); } InterfaceType::Variant(t) => { dst.push(ValType::I32); @@ -185,7 +183,7 @@ impl ComponentTypes { } } - pub(super) fn align(&self, opts: &Options, ty: &InterfaceType) -> usize { + pub(super) fn align(&self, opts: &Options, ty: &InterfaceType) -> u32 { self.size_align(opts, ty).1 } @@ -194,85 +192,12 @@ impl ComponentTypes { // // TODO: this is probably inefficient to entire recalculate at all phases, // seems like it would be best to intern this in some sort of map somewhere. - pub(super) fn size_align(&self, opts: &Options, ty: &InterfaceType) -> (usize, usize) { - match ty { - InterfaceType::Unit => (0, 1), - InterfaceType::Bool | InterfaceType::S8 | InterfaceType::U8 => (1, 1), - InterfaceType::S16 | InterfaceType::U16 => (2, 2), - InterfaceType::S32 - | InterfaceType::U32 - | InterfaceType::Char - | InterfaceType::Float32 => (4, 4), - InterfaceType::S64 | InterfaceType::U64 | InterfaceType::Float64 => (8, 8), - InterfaceType::String | InterfaceType::List(_) => { - ((2 * opts.ptr_size()).into(), opts.ptr_size().into()) - } - - InterfaceType::Record(r) => { - self.record_size_align(opts, self[*r].fields.iter().map(|f| &f.ty)) - } - InterfaceType::Tuple(t) => self.record_size_align(opts, self[*t].types.iter()), - InterfaceType::Flags(f) => match FlagsSize::from_count(self[*f].names.len()) { - FlagsSize::Size0 => (0, 1), - FlagsSize::Size1 => (1, 1), - FlagsSize::Size2 => (2, 2), - FlagsSize::Size4Plus(n) => (n * 4, 4), - }, - InterfaceType::Enum(t) => self.discrim_size_align(self[*t].names.len()), - InterfaceType::Option(t) => { - let ty = &self[*t]; - self.variant_size_align(opts, [&InterfaceType::Unit, ty].into_iter()) - } - InterfaceType::Variant(t) => { - self.variant_size_align(opts, self[*t].cases.iter().map(|c| &c.ty)) - } - InterfaceType::Union(t) => self.variant_size_align(opts, self[*t].types.iter()), - InterfaceType::Expected(t) => { - let e = &self[*t]; - self.variant_size_align(opts, [&e.ok, &e.err].into_iter()) - } - } - } - - pub(super) fn record_size_align<'a>( - &self, - opts: &Options, - fields: impl Iterator, - ) -> (usize, usize) { - let mut size = 0; - let mut align = 1; - for ty in fields { - let (fsize, falign) = self.size_align(opts, ty); - size = align_to(size, falign) + fsize; - align = align.max(falign); - } - (align_to(size, align), align) - } - - fn variant_size_align<'a>( - &self, - opts: &Options, - cases: impl ExactSizeIterator, - ) -> (usize, usize) { - let (discrim_size, mut align) = self.discrim_size_align(cases.len()); - let mut payload_size = 0; - for ty in cases { - let (csize, calign) = self.size_align(opts, ty); - payload_size = payload_size.max(csize); - align = align.max(calign); - } - ( - align_to(align_to(discrim_size, align) + payload_size, align), - align, - ) - } - - fn discrim_size_align<'a>(&self, cases: usize) -> (usize, usize) { - match DiscriminantSize::from_count(cases) { - Some(DiscriminantSize::Size1) => (1, 1), - Some(DiscriminantSize::Size2) => (2, 2), - Some(DiscriminantSize::Size4) => (4, 4), - None => unreachable!(), + pub(super) fn size_align(&self, opts: &Options, ty: &InterfaceType) -> (u32, u32) { + let abi = self.canonical_abi(ty); + if opts.memory64 { + (abi.size64, abi.align64) + } else { + (abi.size32, abi.align32) } } } diff --git a/crates/environ/src/fact/trampoline.rs b/crates/environ/src/fact/trampoline.rs index 6e3fd62c1a..f41699ed76 100644 --- a/crates/environ/src/fact/trampoline.rs +++ b/crates/environ/src/fact/trampoline.rs @@ -16,11 +16,12 @@ //! can be somewhat arbitrary, an intentional decision. use crate::component::{ - ComponentTypes, InterfaceType, StringEncoding, TypeEnumIndex, TypeExpectedIndex, - TypeFlagsIndex, TypeInterfaceIndex, TypeRecordIndex, TypeTupleIndex, TypeUnionIndex, - TypeVariantIndex, FLAG_MAY_ENTER, FLAG_MAY_LEAVE, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, + CanonicalAbiInfo, ComponentTypes, InterfaceType, StringEncoding, TypeEnumIndex, + TypeExpectedIndex, TypeFlagsIndex, TypeInterfaceIndex, TypeOptionIndex, TypeRecordIndex, + TypeTupleIndex, TypeUnionIndex, TypeVariantIndex, VariantInfo, FLAG_MAY_ENTER, FLAG_MAY_LEAVE, + MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, }; -use crate::fact::signature::{align_to, Signature}; +use crate::fact::signature::Signature; use crate::fact::transcode::{FixedEncoding as FE, Transcode, Transcoder}; use crate::fact::traps::Trap; use crate::fact::{AdapterData, Body, Context, Function, FunctionId, Module, Options}; @@ -307,7 +308,12 @@ impl Compiler<'_, '_> { } else { // If there are too many parameters then space is allocated in the // destination module for the parameters via its `realloc` function. - let (size, align) = self.types.record_size_align(lift_opts, dst_tys.iter()); + let abi = CanonicalAbiInfo::record(dst_tys.iter().map(|t| self.types.canonical_abi(t))); + let (size, align) = if lift_opts.memory64 { + (abi.size64, abi.align64) + } else { + (abi.size32, abi.align32) + }; let size = MallocSize::Const(size); Destination::Memory(self.malloc(lift_opts, size, align)) }; @@ -1692,13 +1698,13 @@ impl Compiler<'_, '_> { // Update the two loop pointers if src_size > 0 { self.instruction(LocalGet(cur_src_ptr.idx)); - self.ptr_uconst(src_opts, u32::try_from(src_size).unwrap()); + self.ptr_uconst(src_opts, src_size); self.ptr_add(src_opts); self.instruction(LocalSet(cur_src_ptr.idx)); } if dst_size > 0 { self.instruction(LocalGet(cur_dst_ptr.idx)); - self.ptr_uconst(dst_opts, u32::try_from(dst_size).unwrap()); + self.ptr_uconst(dst_opts, dst_size); self.ptr_add(dst_opts); self.instruction(LocalSet(cur_dst_ptr.idx)); } @@ -1745,7 +1751,7 @@ impl Compiler<'_, '_> { &mut self, opts: &Options, len_local: u32, - elt_size: usize, + elt_size: u32, ) -> TempLocal { // Zero-size types are easy to handle here because the byte size of the // destination is always zero. @@ -1810,7 +1816,7 @@ impl Compiler<'_, '_> { // // The result of the multiplication is saved into a local as well to // get the result afterwards. - self.instruction(I64Const(u32::try_from(elt_size).unwrap().into())); + self.instruction(I64Const(elt_size.into())); self.instruction(I64Mul); let tmp = self.local_tee_new_tmp(ValType::I64); // Branch to success if the upper 32-bits are zero, otherwise @@ -1983,8 +1989,8 @@ impl Compiler<'_, '_> { _ => panic!("expected a variant"), }; - let src_info = VariantInfo::new(self.types, src.opts(), src_ty.cases.iter().map(|c| c.ty)); - let dst_info = VariantInfo::new(self.types, dst.opts(), dst_ty.cases.iter().map(|c| c.ty)); + let src_info = variant_info(self.types, src_ty.cases.iter().map(|c| c.ty)); + let dst_info = variant_info(self.types, dst_ty.cases.iter().map(|c| c.ty)); let iter = src_ty.cases.iter().enumerate().map(|(src_i, src_case)| { let dst_i = dst_ty @@ -2018,8 +2024,8 @@ impl Compiler<'_, '_> { _ => panic!("expected an option"), }; assert_eq!(src_ty.types.len(), dst_ty.types.len()); - let src_info = VariantInfo::new(self.types, src.opts(), src_ty.types.iter().copied()); - let dst_info = VariantInfo::new(self.types, dst.opts(), dst_ty.types.iter().copied()); + let src_info = variant_info(self.types, src_ty.types.iter().copied()); + let dst_info = variant_info(self.types, dst_ty.types.iter().copied()); self.convert_variant( src, @@ -2055,16 +2061,8 @@ impl Compiler<'_, '_> { InterfaceType::Enum(t) => &self.types[*t], _ => panic!("expected an option"), }; - let src_info = VariantInfo::new( - self.types, - src.opts(), - src_ty.names.iter().map(|_| InterfaceType::Unit), - ); - let dst_info = VariantInfo::new( - self.types, - dst.opts(), - dst_ty.names.iter().map(|_| InterfaceType::Unit), - ); + let src_info = variant_info(self.types, src_ty.names.iter().map(|_| InterfaceType::Unit)); + let dst_info = variant_info(self.types, dst_ty.names.iter().map(|_| InterfaceType::Unit)); let unit = &InterfaceType::Unit; self.convert_variant( @@ -2088,19 +2086,19 @@ impl Compiler<'_, '_> { fn translate_option( &mut self, - src_ty: TypeInterfaceIndex, + src_ty: TypeOptionIndex, src: &Source<'_>, dst_ty: &InterfaceType, dst: &Destination, ) { - let src_ty = &self.types[src_ty]; + let src_ty = &self.types[src_ty].ty; let dst_ty = match dst_ty { - InterfaceType::Option(t) => &self.types[*t], + InterfaceType::Option(t) => &self.types[*t].ty, _ => panic!("expected an option"), }; - let src_info = VariantInfo::new(self.types, src.opts(), [InterfaceType::Unit, *src_ty]); - let dst_info = VariantInfo::new(self.types, dst.opts(), [InterfaceType::Unit, *dst_ty]); + let src_info = variant_info(self.types, [InterfaceType::Unit, *src_ty]); + let dst_info = variant_info(self.types, [InterfaceType::Unit, *dst_ty]); self.convert_variant( src, @@ -2138,8 +2136,8 @@ impl Compiler<'_, '_> { _ => panic!("expected an expected"), }; - let src_info = VariantInfo::new(self.types, src.opts(), [src_ty.ok, src_ty.err]); - let dst_info = VariantInfo::new(self.types, dst.opts(), [dst_ty.ok, dst_ty.err]); + let src_info = variant_info(self.types, [src_ty.ok, src_ty.err]); + let dst_info = variant_info(self.types, [dst_ty.ok, dst_ty.err]); self.convert_variant( src, @@ -2316,7 +2314,7 @@ impl Compiler<'_, '_> { self.instruction(GlobalSet(flags_global.as_u32())); } - fn verify_aligned(&mut self, opts: &Options, addr_local: u32, align: usize) { + fn verify_aligned(&mut self, opts: &Options, addr_local: u32, align: u32) { // If the alignment is 1 then everything is trivially aligned and the // check can be omitted. if align == 1 { @@ -2324,7 +2322,7 @@ impl Compiler<'_, '_> { } self.instruction(LocalGet(addr_local)); assert!(align.is_power_of_two()); - self.ptr_uconst(opts, u32::try_from(align - 1).unwrap()); + self.ptr_uconst(opts, align - 1); self.ptr_and(opts); self.ptr_if(opts, BlockType::Empty); self.trap(Trap::UnalignedPointer); @@ -2343,20 +2341,20 @@ impl Compiler<'_, '_> { self.instruction(LocalGet(mem.addr.idx)); self.ptr_uconst(mem.opts, mem.offset); self.ptr_add(mem.opts); - self.ptr_uconst(mem.opts, u32::try_from(align - 1).unwrap()); + self.ptr_uconst(mem.opts, align - 1); self.ptr_and(mem.opts); self.ptr_if(mem.opts, BlockType::Empty); self.trap(Trap::AssertFailed("pointer not aligned")); self.instruction(End); } - fn malloc<'a>(&mut self, opts: &'a Options, size: MallocSize, align: usize) -> Memory<'a> { + fn malloc<'a>(&mut self, opts: &'a Options, size: MallocSize, align: u32) -> Memory<'a> { let realloc = opts.realloc.unwrap(); self.ptr_uconst(opts, 0); self.ptr_uconst(opts, 0); - self.ptr_uconst(opts, u32::try_from(align).unwrap()); + self.ptr_uconst(opts, align); match size { - MallocSize::Const(size) => self.ptr_uconst(opts, u32::try_from(size).unwrap()), + MallocSize::Const(size) => self.ptr_uconst(opts, size), MallocSize::Local(idx) => self.instruction(LocalGet(idx)), } self.instruction(Call(realloc.as_u32())); @@ -2364,12 +2362,7 @@ impl Compiler<'_, '_> { self.memory_operand(opts, addr, align) } - fn memory_operand<'a>( - &mut self, - opts: &'a Options, - addr: TempLocal, - align: usize, - ) -> Memory<'a> { + fn memory_operand<'a>(&mut self, opts: &'a Options, addr: TempLocal, align: u32) -> Memory<'a> { let ret = Memory { addr, offset: 0, @@ -2795,9 +2788,9 @@ impl<'a> Source<'a> { Source::Memory(mem) } Source::Stack(stack) => { - let cnt = types.flatten_types(stack.opts, [ty]).len(); + let cnt = types.flatten_types(stack.opts, [ty]).len() as u32; offset += cnt; - Source::Stack(stack.slice(offset - cnt..offset)) + Source::Stack(stack.slice((offset - cnt) as usize..offset as usize)) } }) } @@ -2815,7 +2808,11 @@ impl<'a> Source<'a> { Source::Stack(s.slice(1..s.locals.len()).slice(0..flat_len)) } Source::Memory(mem) => { - let mem = info.payload_offset(case, mem); + let mem = if mem.opts.memory64 { + mem.bump(info.payload_offset64) + } else { + mem.bump(info.payload_offset32) + }; Source::Memory(mem) } } @@ -2846,9 +2843,9 @@ impl<'a> Destination<'a> { Destination::Memory(mem) } Destination::Stack(s, opts) => { - let cnt = types.flatten_types(opts, [ty]).len(); + let cnt = types.flatten_types(opts, [ty]).len() as u32; offset += cnt; - Destination::Stack(&s[offset - cnt..offset], opts) + Destination::Stack(&s[(offset - cnt) as usize..offset as usize], opts) } }) } @@ -2866,7 +2863,11 @@ impl<'a> Destination<'a> { Destination::Stack(&s[1..][..flat_len], opts) } Destination::Memory(mem) => { - let mem = info.payload_offset(case, mem); + let mem = if mem.opts.memory64 { + mem.bump(info.payload_offset64) + } else { + mem.bump(info.payload_offset32) + }; Destination::Memory(mem) } } @@ -2881,38 +2882,18 @@ impl<'a> Destination<'a> { } fn next_field_offset<'a>( - offset: &mut usize, + offset: &mut u32, types: &ComponentTypes, field: &InterfaceType, mem: &Memory<'a>, ) -> Memory<'a> { - let (size, align) = types.size_align(mem.opts, field); - *offset = align_to(*offset, align) + size; - mem.bump(*offset - size) -} - -struct VariantInfo { - size: DiscriminantSize, - align: usize, -} - -impl VariantInfo { - fn new(types: &ComponentTypes, options: &Options, iter: I) -> VariantInfo - where - I: IntoIterator, - I::IntoIter: ExactSizeIterator, - { - let iter = iter.into_iter(); - let size = DiscriminantSize::from_count(iter.len()).unwrap(); - VariantInfo { - size, - align: usize::from(size).max(iter.map(|i| types.align(options, &i)).max().unwrap_or(1)), - } - } - - fn payload_offset<'a>(&self, _case: &InterfaceType, mem: &Memory<'a>) -> Memory<'a> { - mem.bump(align_to(self.size.into(), self.align)) - } + let abi = types.canonical_abi(field); + let offset = if mem.opts.memory64 { + abi.next_field64(offset) + } else { + abi.next_field32(offset) + }; + mem.bump(offset) } impl<'a> Memory<'a> { @@ -2924,11 +2905,11 @@ impl<'a> Memory<'a> { } } - fn bump(&self, offset: usize) -> Memory<'a> { + fn bump(&self, offset: u32) -> Memory<'a> { Memory { opts: self.opts, addr: TempLocal::new(self.addr.idx, self.addr.ty), - offset: self.offset + u32::try_from(offset).unwrap(), + offset: self.offset + offset, } } } @@ -2949,8 +2930,16 @@ struct VariantCase<'a> { dst_ty: &'a InterfaceType, } +fn variant_info(types: &ComponentTypes, cases: I) -> VariantInfo +where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ + VariantInfo::new(cases.into_iter().map(|i| types.canonical_abi(&i))).0 +} + enum MallocSize { - Const(usize), + Const(u32), Local(u32), } diff --git a/crates/misc/component-test-util/src/lib.rs b/crates/misc/component-test-util/src/lib.rs index eb79e0f969..c5a4970e6f 100644 --- a/crates/misc/component-test-util/src/lib.rs +++ b/crates/misc/component-test-util/src/lib.rs @@ -2,7 +2,7 @@ use anyhow::Result; use arbitrary::Arbitrary; use std::mem::MaybeUninit; use wasmtime::component::__internal::{ - ComponentTypes, InterfaceType, Memory, MemoryMut, Options, StoreOpaque, + CanonicalAbiInfo, ComponentTypes, InterfaceType, Memory, MemoryMut, Options, StoreOpaque, }; use wasmtime::component::{ComponentParams, ComponentType, Func, Lift, Lower, TypedFunc, Val}; use wasmtime::{AsContextMut, Config, Engine, StoreContextMut}; @@ -68,8 +68,7 @@ macro_rules! forward_impls { unsafe impl ComponentType for $a { type Lower = <$b as ComponentType>::Lower; - const SIZE32: usize = <$b as ComponentType>::SIZE32; - const ALIGN32: u32 = <$b as ComponentType>::ALIGN32; + const ABI: CanonicalAbiInfo = <$b as ComponentType>::ABI; #[inline] fn typecheck(ty: &InterfaceType, types: &ComponentTypes) -> Result<()> { diff --git a/crates/wasmtime/src/component/func.rs b/crates/wasmtime/src/component/func.rs index 0942cf7447..28bc5c1601 100644 --- a/crates/wasmtime/src/component/func.rs +++ b/crates/wasmtime/src/component/func.rs @@ -1,5 +1,5 @@ use crate::component::instance::{Instance, InstanceData}; -use crate::component::types::{SizeAndAlignment, Type}; +use crate::component::types::Type; use crate::component::values::Val; use crate::store::{StoreOpaque, Stored}; use crate::{AsContext, AsContextMut, StoreContextMut, ValRaw}; @@ -8,8 +8,8 @@ use std::mem::{self, MaybeUninit}; use std::ptr::NonNull; use std::sync::Arc; use wasmtime_environ::component::{ - CanonicalOptions, ComponentTypes, CoreDef, RuntimeComponentInstanceIndex, TypeFuncIndex, - MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, + CanonicalAbiInfo, CanonicalOptions, ComponentTypes, CoreDef, RuntimeComponentInstanceIndex, + TypeFuncIndex, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, }; use wasmtime_runtime::{Export, ExportFunction, VMTrampoline}; @@ -558,18 +558,15 @@ impl Func { args: &[Val], dst: &mut MaybeUninit<[ValRaw; MAX_FLAT_PARAMS]>, ) -> Result<()> { - let mut size = 0; - let mut alignment = 1; - for ty in params { - alignment = alignment.max(ty.size_and_alignment().alignment); - ty.next_field(&mut size); - } + let abi = CanonicalAbiInfo::record(params.iter().map(|t| t.canonical_abi())); let mut memory = MemoryMut::new(store.as_context_mut(), options); - let ptr = memory.realloc(0, 0, alignment, size)?; + let size = usize::try_from(abi.size32).unwrap(); + let ptr = memory.realloc(0, 0, abi.align32, size)?; let mut offset = ptr; for (ty, arg) in params.iter().zip(args) { - arg.store(&mut memory, ty.next_field(&mut offset))?; + let abi = ty.canonical_abi(); + arg.store(&mut memory, abi.next_field32_size(&mut offset))?; } map_maybe_uninit!(dst[0]).write(ValRaw::i64(ptr as i64)); @@ -582,17 +579,17 @@ impl Func { ty: &Type, src: &mut std::slice::Iter<'_, ValRaw>, ) -> Result { - let SizeAndAlignment { size, alignment } = ty.size_and_alignment(); + let abi = ty.canonical_abi(); // FIXME: needs to read an i64 for memory64 let ptr = usize::try_from(src.next().unwrap().get_u32())?; - if ptr % usize::try_from(alignment)? != 0 { + if ptr % usize::try_from(abi.align32)? != 0 { bail!("return pointer not aligned"); } let bytes = mem .as_slice() .get(ptr..) - .and_then(|b| b.get(..size)) + .and_then(|b| b.get(..usize::try_from(abi.size32).unwrap())) .ok_or_else(|| anyhow::anyhow!("pointer out of bounds of memory"))?; Val::load(ty, mem, bytes) diff --git a/crates/wasmtime/src/component/func/host.rs b/crates/wasmtime/src/component/func/host.rs index b29321439f..6f3c450cb1 100644 --- a/crates/wasmtime/src/component/func/host.rs +++ b/crates/wasmtime/src/component/func/host.rs @@ -1,5 +1,4 @@ use crate::component::func::{Memory, MemoryMut, Options}; -use crate::component::types::SizeAndAlignment; use crate::component::{ComponentParams, ComponentType, Lift, Lower, Type, Val}; use crate::{AsContextMut, StoreContextMut, ValRaw}; use anyhow::{anyhow, bail, Context, Result}; @@ -9,7 +8,8 @@ use std::panic::{self, AssertUnwindSafe}; use std::ptr::NonNull; use std::sync::Arc; use wasmtime_environ::component::{ - ComponentTypes, StringEncoding, TypeFuncIndex, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, + CanonicalAbiInfo, ComponentTypes, StringEncoding, TypeFuncIndex, MAX_FLAT_PARAMS, + MAX_FLAT_RESULTS, }; use wasmtime_runtime::component::{ InstanceFlags, VMComponentContext, VMLowering, VMLoweringCallee, @@ -413,26 +413,19 @@ where .collect::>>()?; ret_index = param_count; } else { - let param_layout = { - let mut size = 0; - let mut alignment = 1; - for ty in params.iter() { - alignment = alignment.max(ty.size_and_alignment().alignment); - ty.next_field(&mut size); - } - SizeAndAlignment { size, alignment } - }; + let param_abi = CanonicalAbiInfo::record(params.iter().map(|t| t.canonical_abi())); let memory = Memory::new(cx.0, &options); - let mut offset = validate_inbounds_dynamic(param_layout, memory.as_slice(), &storage[0])?; + let mut offset = validate_inbounds_dynamic(¶m_abi, memory.as_slice(), &storage[0])?; args = params .iter() .map(|ty| { + let abi = ty.canonical_abi(); + let size = usize::try_from(abi.size32).unwrap(); Val::load( ty, &memory, - &memory.as_slice()[ty.next_field(&mut offset)..] - [..ty.size_and_alignment().size], + &memory.as_slice()[abi.next_field32_size(&mut offset)..][..size], ) }) .collect::>>()?; @@ -451,7 +444,7 @@ where let ret_ptr = &storage[ret_index]; let mut memory = MemoryMut::new(cx.as_context_mut(), &options); let ptr = - validate_inbounds_dynamic(result.size_and_alignment(), memory.as_slice_mut(), ret_ptr)?; + validate_inbounds_dynamic(result.canonical_abi(), memory.as_slice_mut(), ret_ptr)?; ret.store(&mut memory, ptr)?; } @@ -460,17 +453,13 @@ where return Ok(()); } -fn validate_inbounds_dynamic( - SizeAndAlignment { size, alignment }: SizeAndAlignment, - memory: &[u8], - ptr: &ValRaw, -) -> Result { +fn validate_inbounds_dynamic(abi: &CanonicalAbiInfo, memory: &[u8], ptr: &ValRaw) -> Result { // FIXME: needs memory64 support let ptr = usize::try_from(ptr.get_u32())?; - if ptr % usize::try_from(alignment)? != 0 { + if ptr % usize::try_from(abi.align32)? != 0 { bail!("pointer not aligned"); } - let end = match ptr.checked_add(size) { + let end = match ptr.checked_add(usize::try_from(abi.size32).unwrap()) { Some(n) => n, None => bail!("pointer size overflow"), }; diff --git a/crates/wasmtime/src/component/func/typed.rs b/crates/wasmtime/src/component/func/typed.rs index ac4f63a1d0..6a84627834 100644 --- a/crates/wasmtime/src/component/func/typed.rs +++ b/crates/wasmtime/src/component/func/typed.rs @@ -8,7 +8,8 @@ use std::marker; use std::mem::{self, MaybeUninit}; use std::str; use wasmtime_environ::component::{ - ComponentTypes, InterfaceType, StringEncoding, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, + CanonicalAbiInfo, ComponentTypes, InterfaceType, StringEncoding, VariantInfo, MAX_FLAT_PARAMS, + MAX_FLAT_RESULTS, }; /// A statically-typed version of [`Func`] which takes `Params` as input and @@ -363,13 +364,14 @@ pub unsafe trait ComponentType { #[doc(hidden)] type Lower: Copy; - /// The size, in bytes, that this type has in the canonical ABI. + /// The information about this type's canonical ABI (size/align/etc). #[doc(hidden)] - const SIZE32: usize; + const ABI: CanonicalAbiInfo; - /// The alignment, in bytes, that this type has in the canonical ABI. #[doc(hidden)] - const ALIGN32: u32; + const SIZE32: usize = Self::ABI.size32 as usize; + #[doc(hidden)] + const ALIGN32: u32 = Self::ABI.align32; /// Returns the number of core wasm abi values will be used to represent /// this type in its lowered form. @@ -382,14 +384,19 @@ pub unsafe trait ComponentType { mem::size_of::() / mem::size_of::() } - // FIXME: need SIZE64 and ALIGN64 probably - /// Performs a type-check to see whether this component value type matches /// the interface type `ty` provided. #[doc(hidden)] fn typecheck(ty: &InterfaceType, types: &ComponentTypes) -> Result<()>; } +#[doc(hidden)] +pub unsafe trait ComponentVariant: ComponentType { + const CASES: &'static [CanonicalAbiInfo]; + const INFO: VariantInfo = VariantInfo::new_static(Self::CASES); + const PAYLOAD_OFFSET32: usize = Self::INFO.payload_offset32 as usize; +} + /// Host types which can be passed to WebAssembly components. /// /// This trait is implemented for all types that can be passed to components @@ -475,8 +482,7 @@ macro_rules! forward_type_impls { unsafe impl <$($generics)*> ComponentType for $a { type Lower = <$b as ComponentType>::Lower; - const SIZE32: usize = <$b as ComponentType>::SIZE32; - const ALIGN32: u32 = <$b as ComponentType>::ALIGN32; + const ABI: CanonicalAbiInfo = <$b as ComponentType>::ABI; #[inline] fn typecheck(ty: &InterfaceType, types: &ComponentTypes) -> Result<()> { @@ -570,17 +576,11 @@ forward_list_lifts! { // Macro to help generate `ComponentType` implementations for primitive types // such as integers, char, bool, etc. macro_rules! integers { - ($($primitive:ident = $ty:ident in $field:ident/$get:ident,)*) => ($( + ($($primitive:ident = $ty:ident in $field:ident/$get:ident with abi:$abi:ident,)*) => ($( unsafe impl ComponentType for $primitive { type Lower = ValRaw; - const SIZE32: usize = mem::size_of::<$primitive>(); - - // Note that this specifically doesn't use `align_of` as some - // host platforms have a 4-byte alignment for primitive types but - // the canonical abi always has the same size/alignment for these - // types. - const ALIGN32: u32 = mem::size_of::<$primitive>() as u32; + const ABI: CanonicalAbiInfo = CanonicalAbiInfo::$abi; fn typecheck(ty: &InterfaceType, _types: &ComponentTypes) -> Result<()> { match ty { @@ -624,18 +624,18 @@ macro_rules! integers { } integers! { - i8 = S8 in i32/get_i32, - u8 = U8 in u32/get_u32, - i16 = S16 in i32/get_i32, - u16 = U16 in u32/get_u32, - i32 = S32 in i32/get_i32, - u32 = U32 in u32/get_u32, - i64 = S64 in i64/get_i64, - u64 = U64 in u64/get_u64, + i8 = S8 in i32/get_i32 with abi:SCALAR1, + u8 = U8 in u32/get_u32 with abi:SCALAR1, + i16 = S16 in i32/get_i32 with abi:SCALAR2, + u16 = U16 in u32/get_u32 with abi:SCALAR2, + i32 = S32 in i32/get_i32 with abi:SCALAR4, + u32 = U32 in u32/get_u32 with abi:SCALAR4, + i64 = S64 in i64/get_i64 with abi:SCALAR8, + u64 = U64 in u64/get_u64 with abi:SCALAR8, } macro_rules! floats { - ($($float:ident/$get_float:ident = $ty:ident)*) => ($(const _: () = { + ($($float:ident/$get_float:ident = $ty:ident with abi:$abi:ident)*) => ($(const _: () = { /// All floats in-and-out of the canonical abi always have their nan /// payloads canonicalized. conveniently the `NAN` constant in rust has /// the same representation as canonical nan, so we can use that for the @@ -652,11 +652,7 @@ macro_rules! floats { unsafe impl ComponentType for $float { type Lower = ValRaw; - const SIZE32: usize = mem::size_of::<$float>(); - - // note that like integers size is used here instead of alignment to - // respect the canonical abi, not host platforms. - const ALIGN32: u32 = mem::size_of::<$float>() as u32; + const ABI: CanonicalAbiInfo = CanonicalAbiInfo::$abi; fn typecheck(ty: &InterfaceType, _types: &ComponentTypes) -> Result<()> { match ty { @@ -701,15 +697,14 @@ macro_rules! floats { } floats! { - f32/get_f32 = Float32 - f64/get_f64 = Float64 + f32/get_f32 = Float32 with abi:SCALAR4 + f64/get_f64 = Float64 with abi:SCALAR8 } unsafe impl ComponentType for bool { type Lower = ValRaw; - const SIZE32: usize = 1; - const ALIGN32: u32 = 1; + const ABI: CanonicalAbiInfo = CanonicalAbiInfo::SCALAR1; fn typecheck(ty: &InterfaceType, _types: &ComponentTypes) -> Result<()> { match ty { @@ -758,8 +753,7 @@ unsafe impl Lift for bool { unsafe impl ComponentType for char { type Lower = ValRaw; - const SIZE32: usize = 4; - const ALIGN32: u32 = 4; + const ABI: CanonicalAbiInfo = CanonicalAbiInfo::SCALAR4; fn typecheck(ty: &InterfaceType, _types: &ComponentTypes) -> Result<()> { match ty { @@ -810,8 +804,7 @@ const MAX_STRING_BYTE_LENGTH: usize = (1 << 31) - 1; unsafe impl ComponentType for str { type Lower = [ValRaw; 2]; - const SIZE32: usize = 8; - const ALIGN32: u32 = 4; + const ABI: CanonicalAbiInfo = CanonicalAbiInfo::POINTER_PAIR; fn typecheck(ty: &InterfaceType, _types: &ComponentTypes) -> Result<()> { match ty { @@ -1078,8 +1071,7 @@ impl WasmStr { unsafe impl ComponentType for WasmStr { type Lower = ::Lower; - const SIZE32: usize = ::SIZE32; - const ALIGN32: u32 = ::ALIGN32; + const ABI: CanonicalAbiInfo = CanonicalAbiInfo::POINTER_PAIR; fn typecheck(ty: &InterfaceType, _types: &ComponentTypes) -> Result<()> { match ty { @@ -1114,8 +1106,7 @@ where { type Lower = [ValRaw; 2]; - const SIZE32: usize = 8; - const ALIGN32: u32 = 4; + const ABI: CanonicalAbiInfo = CanonicalAbiInfo::POINTER_PAIR; fn typecheck(ty: &InterfaceType, types: &ComponentTypes) -> Result<()> { match ty { @@ -1324,8 +1315,7 @@ raw_wasm_list_accessors! { unsafe impl ComponentType for WasmList { type Lower = <[T] as ComponentType>::Lower; - const SIZE32: usize = <[T] as ComponentType>::SIZE32; - const ALIGN32: u32 = <[T] as ComponentType>::ALIGN32; + const ABI: CanonicalAbiInfo = CanonicalAbiInfo::POINTER_PAIR; fn typecheck(ty: &InterfaceType, types: &ComponentTypes) -> Result<()> { match ty { @@ -1354,24 +1344,6 @@ unsafe impl Lift for WasmList { } } -/// Round `a` up to the next multiple of `align`, assuming that `align` is a power of 2. -#[inline] -pub const fn align_to(a: usize, align: u32) -> usize { - debug_assert!(align.is_power_of_two()); - let align = align as usize; - (a + (align - 1)) & !(align - 1) -} - -/// For a field of type T starting after `offset` bytes, updates the offset to reflect the correct -/// alignment and size of T. Returns the correctly aligned offset for the start of the field. -#[inline] -pub fn next_field(offset: &mut usize) -> usize { - *offset = align_to(*offset, T::ALIGN32); - let result = *offset; - *offset += T::SIZE32; - result -} - /// Verify that the given wasm type is a tuple with the expected fields in the right order. fn typecheck_tuple( ty: &InterfaceType, @@ -1585,17 +1557,24 @@ where { type Lower = TupleLower2<::Lower, T::Lower>; - const SIZE32: usize = align_to(1, T::ALIGN32) + T::SIZE32; - const ALIGN32: u32 = T::ALIGN32; + const ABI: CanonicalAbiInfo = + CanonicalAbiInfo::variant_static(&[<() as ComponentType>::ABI, T::ABI]); fn typecheck(ty: &InterfaceType, types: &ComponentTypes) -> Result<()> { match ty { - InterfaceType::Option(t) => T::typecheck(&types[*t], types), + InterfaceType::Option(t) => T::typecheck(&types[*t].ty, types), other => bail!("expected `option` found `{}`", desc(other)), } } } +unsafe impl ComponentVariant for Option +where + T: ComponentType, +{ + const CASES: &'static [CanonicalAbiInfo] = &[<() as ComponentType>::ABI, T::ABI]; +} + unsafe impl Lower for Option where T: Lower, @@ -1635,7 +1614,7 @@ where } Some(val) => { mem.get::<1>(offset)[0] = 1; - val.store(mem, offset + align_to(1, T::ALIGN32))?; + val.store(mem, offset + (Self::INFO.payload_offset32 as usize))?; } } Ok(()) @@ -1657,7 +1636,7 @@ where fn load(memory: &Memory<'_>, bytes: &[u8]) -> Result { debug_assert!((bytes.as_ptr() as usize) % (Self::ALIGN32 as usize) == 0); let discrim = bytes[0]; - let payload = &bytes[align_to(1, T::ALIGN32)..]; + let payload = &bytes[Self::INFO.payload_offset32 as usize..]; match discrim { 0 => Ok(None), 1 => Ok(Some(T::load(memory, payload)?)), @@ -1687,17 +1666,7 @@ where { type Lower = ResultLower; - const SIZE32: usize = align_to(1, Self::ALIGN32) - + if T::SIZE32 > E::SIZE32 { - T::SIZE32 - } else { - E::SIZE32 - }; - const ALIGN32: u32 = if T::ALIGN32 > E::ALIGN32 { - T::ALIGN32 - } else { - E::ALIGN32 - }; + const ABI: CanonicalAbiInfo = CanonicalAbiInfo::variant_static(&[T::ABI, E::ABI]); fn typecheck(ty: &InterfaceType, types: &ComponentTypes) -> Result<()> { match ty { @@ -1712,6 +1681,14 @@ where } } +unsafe impl ComponentVariant for Result +where + T: ComponentType, + E: ComponentType, +{ + const CASES: &'static [CanonicalAbiInfo] = &[T::ABI, E::ABI]; +} + unsafe impl Lower for Result where T: Lower, @@ -1756,14 +1733,15 @@ where fn store(&self, mem: &mut MemoryMut<'_, U>, offset: usize) -> Result<()> { debug_assert!(offset % (Self::ALIGN32 as usize) == 0); + let payload_offset = Self::INFO.payload_offset32 as usize; match self { Ok(e) => { mem.get::<1>(offset)[0] = 0; - e.store(mem, offset + align_to(1, Self::ALIGN32))?; + e.store(mem, offset + payload_offset)?; } Err(e) => { mem.get::<1>(offset)[0] = 1; - e.store(mem, offset + align_to(1, Self::ALIGN32))?; + e.store(mem, offset + payload_offset)?; } } Ok(()) @@ -1804,9 +1782,8 @@ where fn load(memory: &Memory<'_>, bytes: &[u8]) -> Result { debug_assert!((bytes.as_ptr() as usize) % (Self::ALIGN32 as usize) == 0); - let align = Self::ALIGN32; let discrim = bytes[0]; - let payload = &bytes[align_to(1, align)..]; + let payload = &bytes[Self::INFO.payload_offset32 as usize..]; match discrim { 0 => Ok(Ok(T::load(memory, &payload[..T::SIZE32])?)), 1 => Ok(Err(E::load(memory, &payload[..E::SIZE32])?)), @@ -1832,22 +1809,9 @@ macro_rules! impl_component_ty_for_tuples { { type Lower = []<$($t::Lower),*>; - const SIZE32: usize = { - let mut _size = 0; - $( - _size = align_to(_size, $t::ALIGN32); - _size += $t::SIZE32; - )* - align_to(_size, Self::ALIGN32) - }; - - const ALIGN32: u32 = { - let mut _align = 1; - $(if $t::ALIGN32 > _align { - _align = $t::ALIGN32; - })* - _align - }; + const ABI: CanonicalAbiInfo = CanonicalAbiInfo::record_static(&[ + $($t::ABI),* + ]); fn typecheck( ty: &InterfaceType, @@ -1875,7 +1839,7 @@ macro_rules! impl_component_ty_for_tuples { fn store(&self, _memory: &mut MemoryMut<'_, U>, mut _offset: usize) -> Result<()> { debug_assert!(_offset % (Self::ALIGN32 as usize) == 0); let ($($t,)*) = self; - $($t.store(_memory, next_field::<$t>(&mut _offset))?;)* + $($t.store(_memory, $t::ABI.next_field32_size(&mut _offset))?;)* Ok(()) } } @@ -1891,7 +1855,7 @@ macro_rules! impl_component_ty_for_tuples { fn load(_memory: &Memory<'_>, bytes: &[u8]) -> Result { debug_assert!((bytes.as_ptr() as usize) % (Self::ALIGN32 as usize) == 0); let mut _offset = 0; - $(let $t = $t::load(_memory, &bytes[next_field::<$t>(&mut _offset)..][..$t::SIZE32])?;)* + $(let $t = $t::load(_memory, &bytes[$t::ABI.next_field32_size(&mut _offset)..][..$t::SIZE32])?;)* Ok(($($t,)*)) } } diff --git a/crates/wasmtime/src/component/mod.rs b/crates/wasmtime/src/component/mod.rs index 527409b91e..3b0924dbed 100644 --- a/crates/wasmtime/src/component/mod.rs +++ b/crates/wasmtime/src/component/mod.rs @@ -28,14 +28,14 @@ pub use wasmtime_component_macro::{flags, ComponentType, Lift, Lower}; #[doc(hidden)] pub mod __internal { pub use super::func::{ - align_to, format_flags, next_field, typecheck_enum, typecheck_flags, typecheck_record, - typecheck_union, typecheck_variant, MaybeUninitExt, Memory, MemoryMut, Options, + format_flags, typecheck_enum, typecheck_flags, typecheck_record, typecheck_union, + typecheck_variant, ComponentVariant, MaybeUninitExt, Memory, MemoryMut, Options, }; pub use crate::map_maybe_uninit; pub use crate::store::StoreOpaque; pub use anyhow; pub use wasmtime_environ; - pub use wasmtime_environ::component::{ComponentTypes, InterfaceType}; + pub use wasmtime_environ::component::{CanonicalAbiInfo, ComponentTypes, InterfaceType}; } pub(crate) use self::store::ComponentStoreData; diff --git a/crates/wasmtime/src/component/types.rs b/crates/wasmtime/src/component/types.rs index e87a00cc88..cc065d72fc 100644 --- a/crates/wasmtime/src/component/types.rs +++ b/crates/wasmtime/src/component/types.rs @@ -1,16 +1,15 @@ //! This module defines the `Type` type, representing the dynamic form of a component interface type. -use crate::component::func; use crate::component::values::{self, Val}; use anyhow::{anyhow, Result}; use std::fmt; use std::mem; use std::ops::Deref; use std::sync::Arc; -use wasmtime_component_util::{DiscriminantSize, FlagsSize}; use wasmtime_environ::component::{ - ComponentTypes, InterfaceType, TypeEnumIndex, TypeExpectedIndex, TypeFlagsIndex, - TypeInterfaceIndex, TypeRecordIndex, TypeTupleIndex, TypeUnionIndex, TypeVariantIndex, + CanonicalAbiInfo, ComponentTypes, InterfaceType, TypeEnumIndex, TypeExpectedIndex, + TypeFlagsIndex, TypeInterfaceIndex, TypeOptionIndex, TypeRecordIndex, TypeTupleIndex, + TypeUnionIndex, TypeVariantIndex, VariantInfo, }; #[derive(Clone)] @@ -125,6 +124,10 @@ impl Variant { ty: Type::from(&case.ty, &self.0.types), }) } + + pub(crate) fn variant_info(&self) -> &VariantInfo { + &self.0.types[self.0.index].info + } } /// An `enum` interface type @@ -144,6 +147,10 @@ impl Enum { .iter() .map(|name| name.deref()) } + + pub(crate) fn variant_info(&self) -> &VariantInfo { + &self.0.types[self.0.index].info + } } /// A `union` interface type @@ -163,11 +170,15 @@ impl Union { .iter() .map(|ty| Type::from(ty, &self.0.types)) } + + pub(crate) fn variant_info(&self) -> &VariantInfo { + &self.0.types[self.0.index].info + } } /// An `option` interface type #[derive(Clone, PartialEq, Eq, Debug)] -pub struct Option(Handle); +pub struct Option(Handle); impl Option { /// Instantiate this type with the specified `value`. @@ -177,7 +188,11 @@ impl Option { /// Retrieve the type parameter for this `option`. pub fn ty(&self) -> Type { - Type::from(&self.0.types[self.0.index], &self.0.types) + Type::from(&self.0.types[self.0.index].ty, &self.0.types) + } + + pub(crate) fn variant_info(&self) -> &VariantInfo { + &self.0.types[self.0.index].info } } @@ -200,6 +215,10 @@ impl Expected { pub fn err(&self) -> Type { Type::from(&self.0.types[self.0.index].err, &self.0.types) } + + pub(crate) fn variant_info(&self) -> &VariantInfo { + &self.0.types[self.0.index].info + } } /// A `flags` interface type @@ -221,13 +240,6 @@ impl Flags { } } -/// Represents the size and alignment requirements of the heap-serialized form of a type -#[derive(Debug)] -pub(crate) struct SizeAndAlignment { - pub(crate) size: usize, - pub(crate) alignment: u32, -} - /// Represents a component model interface type #[derive(Clone, PartialEq, Eq, Debug)] pub enum Type { @@ -554,119 +566,22 @@ impl Type { } /// Calculate the size and alignment requirements for the specified type. - pub(crate) fn size_and_alignment(&self) -> SizeAndAlignment { + pub(crate) fn canonical_abi(&self) -> &CanonicalAbiInfo { match self { - Type::Unit => SizeAndAlignment { - size: 0, - alignment: 1, - }, - - Type::Bool | Type::S8 | Type::U8 => SizeAndAlignment { - size: 1, - alignment: 1, - }, - - Type::S16 | Type::U16 => SizeAndAlignment { - size: 2, - alignment: 2, - }, - - Type::S32 | Type::U32 | Type::Char | Type::Float32 => SizeAndAlignment { - size: 4, - alignment: 4, - }, - - Type::S64 | Type::U64 | Type::Float64 => SizeAndAlignment { - size: 8, - alignment: 8, - }, - - Type::String | Type::List(_) => SizeAndAlignment { - size: 8, - alignment: 4, - }, - - Type::Record(handle) => { - record_size_and_alignment(handle.fields().map(|field| field.ty)) - } - - Type::Tuple(handle) => record_size_and_alignment(handle.types()), - - Type::Variant(handle) => variant_size_and_alignment(handle.cases().map(|case| case.ty)), - - Type::Enum(handle) => variant_size_and_alignment(handle.names().map(|_| Type::Unit)), - - Type::Union(handle) => variant_size_and_alignment(handle.types()), - - Type::Option(handle) => { - variant_size_and_alignment([Type::Unit, handle.ty()].into_iter()) - } - - Type::Expected(handle) => { - variant_size_and_alignment([handle.ok(), handle.err()].into_iter()) - } - - Type::Flags(handle) => match FlagsSize::from_count(handle.names().len()) { - FlagsSize::Size0 => SizeAndAlignment { - size: 0, - alignment: 1, - }, - FlagsSize::Size1 => SizeAndAlignment { - size: 1, - alignment: 1, - }, - FlagsSize::Size2 => SizeAndAlignment { - size: 2, - alignment: 2, - }, - FlagsSize::Size4Plus(n) => SizeAndAlignment { - size: n * 4, - alignment: 4, - }, - }, + Type::Unit => &CanonicalAbiInfo::ZERO, + Type::Bool | Type::S8 | Type::U8 => &CanonicalAbiInfo::SCALAR1, + Type::S16 | Type::U16 => &CanonicalAbiInfo::SCALAR2, + Type::S32 | Type::U32 | Type::Char | Type::Float32 => &CanonicalAbiInfo::SCALAR4, + Type::S64 | Type::U64 | Type::Float64 => &CanonicalAbiInfo::SCALAR8, + Type::String | Type::List(_) => &CanonicalAbiInfo::POINTER_PAIR, + Type::Record(handle) => &handle.0.types[handle.0.index].abi, + Type::Tuple(handle) => &handle.0.types[handle.0.index].abi, + Type::Variant(handle) => &handle.0.types[handle.0.index].abi, + Type::Enum(handle) => &handle.0.types[handle.0.index].abi, + Type::Union(handle) => &handle.0.types[handle.0.index].abi, + Type::Option(handle) => &handle.0.types[handle.0.index].abi, + Type::Expected(handle) => &handle.0.types[handle.0.index].abi, + Type::Flags(handle) => &handle.0.types[handle.0.index].abi, } } - - /// Calculate the aligned offset of a field of this type, updating `offset` to point to just after that field. - pub(crate) fn next_field(&self, offset: &mut usize) -> usize { - let SizeAndAlignment { size, alignment } = self.size_and_alignment(); - *offset = func::align_to(*offset, alignment); - let result = *offset; - *offset += size; - result - } -} - -fn record_size_and_alignment(types: impl Iterator) -> SizeAndAlignment { - let mut offset = 0; - let mut align = 1; - for ty in types { - let SizeAndAlignment { size, alignment } = ty.size_and_alignment(); - offset = func::align_to(offset, alignment) + size; - align = align.max(alignment); - } - - SizeAndAlignment { - size: func::align_to(offset, align), - alignment: align, - } -} - -fn variant_size_and_alignment(types: impl ExactSizeIterator) -> SizeAndAlignment { - let discriminant_size = DiscriminantSize::from_count(types.len()).unwrap(); - let mut alignment = u32::from(discriminant_size); - let mut size = 0; - for ty in types { - let size_and_alignment = ty.size_and_alignment(); - alignment = alignment.max(size_and_alignment.alignment); - size = size.max(size_and_alignment.size); - } - - SizeAndAlignment { - size: func::align_to( - func::align_to(usize::from(discriminant_size), alignment) + size, - alignment, - ), - alignment, - } } diff --git a/crates/wasmtime/src/component/values.rs b/crates/wasmtime/src/component/values.rs index 65fd2280b9..b2e2b75cb1 100644 --- a/crates/wasmtime/src/component/values.rs +++ b/crates/wasmtime/src/component/values.rs @@ -1,5 +1,5 @@ -use crate::component::func::{self, Lift, Lower, Memory, MemoryMut, Options}; -use crate::component::types::{self, SizeAndAlignment, Type}; +use crate::component::func::{Lift, Lower, Memory, MemoryMut, Options}; +use crate::component::types::{self, Type}; use crate::store::StoreOpaque; use crate::{AsContextMut, StoreContextMut, ValRaw}; use anyhow::{anyhow, bail, Context, Error, Result}; @@ -9,6 +9,7 @@ use std::iter; use std::mem::MaybeUninit; use std::ops::Deref; use wasmtime_component_util::{DiscriminantSize, FlagsSize}; +use wasmtime_environ::component::VariantInfo; #[derive(PartialEq, Eq, Clone)] pub struct List { @@ -700,8 +701,12 @@ impl Val { values: load_record(handle.types(), mem, bytes)?, }), Type::Variant(handle) => { - let (discriminant, value) = - load_variant(ty, handle.cases().map(|case| case.ty), mem, bytes)?; + let (discriminant, value) = load_variant( + handle.variant_info(), + handle.cases().map(|case| case.ty), + mem, + bytes, + )?; Val::Variant(Variant { ty: handle.clone(), @@ -710,8 +715,12 @@ impl Val { }) } Type::Enum(handle) => { - let (discriminant, _) = - load_variant(ty, handle.names().map(|_| Type::Unit), mem, bytes)?; + let (discriminant, _) = load_variant( + handle.variant_info(), + handle.names().map(|_| Type::Unit), + mem, + bytes, + )?; Val::Enum(Enum { ty: handle.clone(), @@ -719,7 +728,8 @@ impl Val { }) } Type::Union(handle) => { - let (discriminant, value) = load_variant(ty, handle.types(), mem, bytes)?; + let (discriminant, value) = + load_variant(handle.variant_info(), handle.types(), mem, bytes)?; Val::Union(Union { ty: handle.clone(), @@ -728,8 +738,12 @@ impl Val { }) } Type::Option(handle) => { - let (discriminant, value) = - load_variant(ty, [Type::Unit, handle.ty()].into_iter(), mem, bytes)?; + let (discriminant, value) = load_variant( + handle.variant_info(), + [Type::Unit, handle.ty()].into_iter(), + mem, + bytes, + )?; Val::Option(Option { ty: handle.clone(), @@ -738,8 +752,12 @@ impl Val { }) } Type::Expected(handle) => { - let (discriminant, value) = - load_variant(ty, [handle.ok(), handle.err()].into_iter(), mem, bytes)?; + let (discriminant, value) = load_variant( + handle.variant_info(), + [handle.ok(), handle.err()].into_iter(), + mem, + bytes, + )?; Val::Expected(Expected { ty: handle.clone(), @@ -845,7 +863,7 @@ impl Val { /// Serialize this value to the heap at the specified memory location. pub(crate) fn store(&self, mem: &mut MemoryMut<'_, T>, offset: usize) -> Result<()> { - debug_assert!(offset % usize::try_from(self.ty().size_and_alignment().alignment)? == 0); + debug_assert!(offset % usize::try_from(self.ty().canonical_abi().align32)? == 0); match self { Val::Unit => (), @@ -871,35 +889,39 @@ impl Val { Val::Record(Record { values, .. }) | Val::Tuple(Tuple { values, .. }) => { let mut offset = offset; for value in values.deref() { - value.store(mem, value.ty().next_field(&mut offset))?; + value.store( + mem, + value.ty().canonical_abi().next_field32_size(&mut offset), + )?; } } Val::Variant(Variant { discriminant, value, ty, - }) => self.store_variant(*discriminant, value, ty.cases().len(), mem, offset)?, + }) => self.store_variant(*discriminant, value, ty.variant_info(), mem, offset)?, Val::Enum(Enum { discriminant, ty }) => { - self.store_variant(*discriminant, &Val::Unit, ty.names().len(), mem, offset)? + self.store_variant(*discriminant, &Val::Unit, ty.variant_info(), mem, offset)? } Val::Union(Union { discriminant, value, ty, - }) => self.store_variant(*discriminant, value, ty.types().len(), mem, offset)?, + }) => self.store_variant(*discriminant, value, ty.variant_info(), mem, offset)?, Val::Option(Option { discriminant, value, - .. - }) - | Val::Expected(Expected { + ty, + }) => self.store_variant(*discriminant, value, ty.variant_info(), mem, offset)?, + + Val::Expected(Expected { discriminant, value, - .. - }) => self.store_variant(*discriminant, value, 2, mem, offset)?, + ty, + }) => self.store_variant(*discriminant, value, ty.variant_info(), mem, offset)?, Val::Flags(Flags { count, value, .. }) => { match FlagsSize::from_count(*count as usize) { @@ -924,34 +946,26 @@ impl Val { &self, discriminant: u32, value: &Val, - case_count: usize, + info: &VariantInfo, mem: &mut MemoryMut<'_, T>, offset: usize, ) -> Result<()> { - let discriminant_size = DiscriminantSize::from_count(case_count).unwrap(); - match discriminant_size { + match info.size { DiscriminantSize::Size1 => u8::try_from(discriminant).unwrap().store(mem, offset)?, DiscriminantSize::Size2 => u16::try_from(discriminant).unwrap().store(mem, offset)?, - DiscriminantSize::Size4 => (discriminant).store(mem, offset)?, + DiscriminantSize::Size4 => discriminant.store(mem, offset)?, } - value.store( - mem, - offset - + func::align_to( - discriminant_size.into(), - self.ty().size_and_alignment().alignment, - ), - ) + let offset = offset + usize::try_from(info.payload_offset32).unwrap(); + value.store(mem, offset) } } fn load_list(handle: &types::List, mem: &Memory, ptr: usize, len: usize) -> Result { let element_type = handle.ty(); - let SizeAndAlignment { - size: element_size, - alignment: element_alignment, - } = element_type.size_and_alignment(); + let abi = element_type.canonical_abi(); + let element_size = usize::try_from(abi.size32).unwrap(); + let element_alignment = abi.align32; match len .checked_mul(element_size) @@ -986,25 +1000,24 @@ fn load_record( let mut offset = 0; types .map(|ty| { - Val::load( - &ty, - mem, - &bytes[ty.next_field(&mut offset)..][..ty.size_and_alignment().size], - ) + let abi = ty.canonical_abi(); + let offset = abi.next_field32(&mut offset); + let offset = usize::try_from(offset).unwrap(); + let size = usize::try_from(abi.size32).unwrap(); + Val::load(&ty, mem, &bytes[offset..][..size]) }) .collect() } fn load_variant( - ty: &Type, + info: &VariantInfo, mut types: impl ExactSizeIterator, mem: &Memory, bytes: &[u8], ) -> Result<(u32, Val)> { - let discriminant_size = DiscriminantSize::from_count(types.len()).unwrap(); - let discriminant = match discriminant_size { - DiscriminantSize::Size1 => u8::load(mem, &bytes[..1])? as u32, - DiscriminantSize::Size2 => u16::load(mem, &bytes[..2])? as u32, + let discriminant = match info.size { + DiscriminantSize::Size1 => u32::from(u8::load(mem, &bytes[..1])?), + DiscriminantSize::Size2 => u32::from(u16::load(mem, &bytes[..2])?), DiscriminantSize::Size4 => u32::load(mem, &bytes[..4])?, }; let case_ty = types.nth(discriminant as usize).ok_or_else(|| { @@ -1014,14 +1027,9 @@ fn load_variant( types.len() ) })?; - let value = Val::load( - &case_ty, - mem, - &bytes[func::align_to( - usize::from(discriminant_size), - ty.size_and_alignment().alignment, - )..][..case_ty.size_and_alignment().size], - )?; + let payload_offset = usize::try_from(info.payload_offset32).unwrap(); + let case_size = usize::try_from(case_ty.canonical_abi().size32).unwrap(); + let value = Val::load(&case_ty, mem, &bytes[payload_offset..][..case_size])?; Ok((discriminant, value)) } @@ -1050,19 +1058,18 @@ fn lower_list( mem: &mut MemoryMut<'_, T>, items: &[Val], ) -> Result<(usize, usize)> { - let SizeAndAlignment { - size: element_size, - alignment: element_alignment, - } = element_type.size_and_alignment(); + let abi = element_type.canonical_abi(); + let elt_size = usize::try_from(abi.size32)?; + let elt_align = abi.align32; let size = items .len() - .checked_mul(element_size) + .checked_mul(elt_size) .ok_or_else(|| anyhow::anyhow!("size overflow copying a list"))?; - let ptr = mem.realloc(0, 0, element_alignment, size)?; + let ptr = mem.realloc(0, 0, elt_align, size)?; let mut element_ptr = ptr; for item in items { item.store(mem, element_ptr)?; - element_ptr += element_size; + element_ptr += elt_size; } Ok((ptr, items.len())) }