From ce7bbef24df39ee7060bf786a7c95888d291b53a Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Thu, 28 Jul 2022 11:47:15 -0500 Subject: [PATCH] Implement other variant-like types in adapter fusion (#4547) This commit fills out the adapter fusion compiler for the `union`, `enum`, `option,` and `result` types. The preexisting support for `variant` types was refactored slightly to be extensible to all of these other types and they all now call into the same common translation code. --- .../fuzz/fuzz_targets/fact-valid-module.rs | 30 +++ crates/environ/src/component/types.rs | 46 ++-- crates/environ/src/fact/trampoline.rs | 218 ++++++++++++++++-- 3 files changed, 258 insertions(+), 36 deletions(-) diff --git a/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs b/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs index f00b27b6a1..ba5c820ea9 100644 --- a/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs +++ b/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs @@ -16,6 +16,9 @@ use wasmparser::{Validator, WasmFeatures}; use wasmtime_environ::component::*; use wasmtime_environ::fact::Module; +// Allow inflating to 16 bits but don't go further. +const MAX_ENUM_SIZE: usize = 257; + #[derive(Arbitrary, Debug)] struct GenAdapterModule { debug: bool, @@ -55,6 +58,10 @@ enum ValType { Record(Vec), Tuple(Vec), Variant(NonZeroLenVec), + Union(NonZeroLenVec), + Enum(usize), + Option(Box), + Expected(Box, Box), } #[derive(Copy, Clone, Arbitrary, Debug)] @@ -240,6 +247,29 @@ fn intern(types: &mut ComponentTypesBuilder, ty: &ValType) -> InterfaceType { }; InterfaceType::Variant(types.add_variant_type(ty)) } + ValType::Union(tys) => { + let ty = TypeUnion { + types: tys.0.iter().map(|ty| intern(types, ty)).collect(), + }; + InterfaceType::Union(types.add_union_type(ty)) + } + ValType::Enum(size) => { + let size = size % MAX_ENUM_SIZE; + let size = if size == 0 { 1 } else { size }; + let ty = TypeEnum { + names: (0..size).map(|i| format!("c{i}")).collect(), + }; + InterfaceType::Enum(types.add_enum_type(ty)) + } + ValType::Option(ty) => { + let ty = intern(types, ty); + InterfaceType::Option(types.add_interface_type(ty)) + } + ValType::Expected(ok, err) => { + let ok = intern(types, ok); + let err = intern(types, err); + InterfaceType::Expected(types.add_expected_type(TypeExpected { ok, err })) + } } } diff --git a/crates/environ/src/component/types.rs b/crates/environ/src/component/types.rs index 4fec635477..b1b553782e 100644 --- a/crates/environ/src/component/types.rs +++ b/crates/environ/src/component/types.rs @@ -588,7 +588,7 @@ impl ComponentTypesBuilder { } wasmparser::ComponentDefinedType::List(e) => { let ty = self.valtype(e); - InterfaceType::List(self.intern_interface_type(ty)) + InterfaceType::List(self.add_interface_type(ty)) } wasmparser::ComponentDefinedType::Tuple(e) => InterfaceType::Tuple(self.tuple_type(e)), wasmparser::ComponentDefinedType::Flags(e) => InterfaceType::Flags(self.flags_type(e)), @@ -596,7 +596,7 @@ impl ComponentTypesBuilder { wasmparser::ComponentDefinedType::Union(e) => InterfaceType::Union(self.union_type(e)), wasmparser::ComponentDefinedType::Option(e) => { let ty = self.valtype(e); - InterfaceType::Option(self.intern_interface_type(ty)) + InterfaceType::Option(self.add_interface_type(ty)) } wasmparser::ComponentDefinedType::Expected { ok, error } => { InterfaceType::Expected(self.expected_type(ok, error)) @@ -618,14 +618,6 @@ impl ComponentTypesBuilder { } } - fn intern_interface_type(&mut self, ty: InterfaceType) -> TypeInterfaceIndex { - intern( - &mut self.interface_types, - &mut self.component_types.interface_types, - ty, - ) - } - fn record_type(&mut self, record: &[(&str, wasmparser::ComponentValType)]) -> TypeRecordIndex { let record = TypeRecord { fields: record @@ -675,14 +667,14 @@ impl ComponentTypesBuilder { let e = TypeEnum { names: variants.iter().map(|s| s.to_string()).collect(), }; - intern(&mut self.enums, &mut self.component_types.enums, e) + self.add_enum_type(e) } fn union_type(&mut self, types: &[wasmparser::ComponentValType]) -> TypeUnionIndex { let union = TypeUnion { types: types.iter().map(|ty| self.valtype(ty)).collect(), }; - intern(&mut self.unions, &mut self.component_types.unions, union) + self.add_union_type(union) } fn expected_type( @@ -694,11 +686,7 @@ impl ComponentTypesBuilder { ok: self.valtype(ok), err: self.valtype(err), }; - intern( - &mut self.expecteds, - &mut self.component_types.expecteds, - expected, - ) + self.add_expected_type(expected) } /// Interns a new function type within this type information. @@ -720,6 +708,30 @@ impl ComponentTypesBuilder { pub fn add_variant_type(&mut self, ty: TypeVariant) -> TypeVariantIndex { intern(&mut self.variants, &mut self.component_types.variants, ty) } + + /// Interns a new union type within this type information. + pub fn add_union_type(&mut self, ty: TypeUnion) -> TypeUnionIndex { + intern(&mut self.unions, &mut self.component_types.unions, ty) + } + + /// Interns a new enum type within this type information. + pub fn add_enum_type(&mut self, ty: TypeEnum) -> TypeEnumIndex { + intern(&mut self.enums, &mut self.component_types.enums, ty) + } + + /// Interns a new expected type within this type information. + pub fn add_expected_type(&mut self, ty: TypeExpected) -> TypeExpectedIndex { + intern(&mut self.expecteds, &mut self.component_types.expecteds, ty) + } + + /// Interns a new expected type within this type information. + pub fn add_interface_type(&mut self, ty: InterfaceType) -> TypeInterfaceIndex { + intern( + &mut self.interface_types, + &mut self.component_types.interface_types, + ty, + ) + } } // Forward the indexing impl to the internal `TypeTables` diff --git a/crates/environ/src/fact/trampoline.rs b/crates/environ/src/fact/trampoline.rs index 9973f60a65..d4b2f8ed32 100644 --- a/crates/environ/src/fact/trampoline.rs +++ b/crates/environ/src/fact/trampoline.rs @@ -16,8 +16,9 @@ //! can be somewhat arbitrary, an intentional decision. use crate::component::{ - InterfaceType, TypeRecordIndex, TypeTupleIndex, TypeVariantIndex, FLAG_MAY_ENTER, - FLAG_MAY_LEAVE, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, + InterfaceType, TypeEnumIndex, TypeExpectedIndex, TypeInterfaceIndex, TypeRecordIndex, + TypeTupleIndex, TypeUnionIndex, TypeVariantIndex, FLAG_MAY_ENTER, FLAG_MAY_LEAVE, + MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, }; use crate::fact::core_types::CoreTypes; use crate::fact::signature::{align_to, Signature}; @@ -350,6 +351,10 @@ impl Compiler<'_, '_> { InterfaceType::Record(t) => self.translate_record(*t, src, dst_ty, dst), InterfaceType::Tuple(t) => self.translate_tuple(*t, src, dst_ty, dst), InterfaceType::Variant(v) => self.translate_variant(*v, src, dst_ty, dst), + InterfaceType::Union(u) => self.translate_union(*u, src, dst_ty, dst), + InterfaceType::Enum(t) => self.translate_enum(*t, src, dst_ty, dst), + InterfaceType::Option(t) => self.translate_option(*t, src, dst_ty, dst), + InterfaceType::Expected(t) => self.translate_expected(*t, src, dst_ty, dst), InterfaceType::String => { // consider this field used for now until this is fully @@ -686,6 +691,174 @@ impl Compiler<'_, '_> { let src_disc_size = DiscriminantSize::from_count(src_ty.cases.len()).unwrap(); let dst_disc_size = DiscriminantSize::from_count(dst_ty.cases.len()).unwrap(); + let iter = src_ty.cases.iter().enumerate().map(|(src_i, src_case)| { + let dst_i = dst_ty + .cases + .iter() + .position(|c| c.name == src_case.name) + .unwrap(); + let dst_case = &dst_ty.cases[dst_i]; + let src_i = u32::try_from(src_i).unwrap(); + let dst_i = u32::try_from(dst_i).unwrap(); + VariantCase { + src_i, + src_ty: &src_case.ty, + dst_i, + dst_ty: &dst_case.ty, + } + }); + self.convert_variant(src, src_disc_size, dst, dst_disc_size, iter); + } + + fn translate_union( + &mut self, + src_ty: TypeUnionIndex, + src: &Source<'_>, + dst_ty: &InterfaceType, + dst: &Destination, + ) { + let src_ty = &self.module.types[src_ty]; + let dst_ty = match dst_ty { + InterfaceType::Union(t) => &self.module.types[*t], + _ => panic!("expected an option"), + }; + assert_eq!(src_ty.types.len(), dst_ty.types.len()); + + self.convert_variant( + src, + DiscriminantSize::Size1, + dst, + DiscriminantSize::Size1, + src_ty + .types + .iter() + .zip(dst_ty.types.iter()) + .enumerate() + .map(|(i, (src_ty, dst_ty))| { + let i = u32::try_from(i).unwrap(); + VariantCase { + src_i: i, + dst_i: i, + src_ty, + dst_ty, + } + }), + ); + } + + fn translate_enum( + &mut self, + src_ty: TypeEnumIndex, + src: &Source<'_>, + dst_ty: &InterfaceType, + dst: &Destination, + ) { + let src_ty = &self.module.types[src_ty]; + let dst_ty = match dst_ty { + InterfaceType::Enum(t) => &self.module.types[*t], + _ => panic!("expected an option"), + }; + + let unit = &InterfaceType::Unit; + self.convert_variant( + src, + DiscriminantSize::from_count(src_ty.names.len()).unwrap(), + dst, + DiscriminantSize::from_count(dst_ty.names.len()).unwrap(), + src_ty.names.iter().enumerate().map(|(src_i, src_name)| { + let dst_i = dst_ty.names.iter().position(|n| n == src_name).unwrap(); + let src_i = u32::try_from(src_i).unwrap(); + let dst_i = u32::try_from(dst_i).unwrap(); + VariantCase { + src_i, + dst_i, + src_ty: unit, + dst_ty: unit, + } + }), + ); + } + + fn translate_option( + &mut self, + src_ty: TypeInterfaceIndex, + src: &Source<'_>, + dst_ty: &InterfaceType, + dst: &Destination, + ) { + let src_ty = &self.module.types[src_ty]; + let dst_ty = match dst_ty { + InterfaceType::Option(t) => &self.module.types[*t], + _ => panic!("expected an option"), + }; + + self.convert_variant( + src, + DiscriminantSize::Size1, + dst, + DiscriminantSize::Size1, + [ + VariantCase { + src_i: 0, + dst_i: 0, + src_ty: &InterfaceType::Unit, + dst_ty: &InterfaceType::Unit, + }, + VariantCase { + src_i: 1, + dst_i: 1, + src_ty, + dst_ty, + }, + ] + .into_iter(), + ); + } + + fn translate_expected( + &mut self, + src_ty: TypeExpectedIndex, + src: &Source<'_>, + dst_ty: &InterfaceType, + dst: &Destination, + ) { + let src_ty = &self.module.types[src_ty]; + let dst_ty = match dst_ty { + InterfaceType::Expected(t) => &self.module.types[*t], + _ => panic!("expected an expected"), + }; + + self.convert_variant( + src, + DiscriminantSize::Size1, + dst, + DiscriminantSize::Size1, + [ + VariantCase { + src_i: 0, + dst_i: 0, + src_ty: &src_ty.ok, + dst_ty: &dst_ty.ok, + }, + VariantCase { + src_i: 1, + dst_i: 1, + src_ty: &src_ty.err, + dst_ty: &dst_ty.err, + }, + ] + .into_iter(), + ); + } + + fn convert_variant<'a>( + &mut self, + src: &Source<'_>, + src_disc_size: DiscriminantSize, + dst: &Destination, + dst_disc_size: DiscriminantSize, + src_cases: impl ExactSizeIterator>, + ) { // The outermost block is special since it has the result type of the // translation here. That will depend on the `dst`. let outer_block_ty = match dst { @@ -703,7 +876,8 @@ impl Compiler<'_, '_> { // After the outermost block generate a new block for each of the // remaining cases. - for _ in 0..src_ty.cases.len() - 1 { + let src_cases_len = src_cases.len(); + for _ in 0..src_cases_len - 1 { self.instruction(Block(BlockType::Empty)); } @@ -727,7 +901,7 @@ impl Compiler<'_, '_> { // Generate the `br_table` for the discriminant. Each case has an // offset of 1 to skip the trapping block. let mut targets = Vec::new(); - for i in 0..src_ty.cases.len() { + for i in 0..src_cases_len { targets.push((i + 1) as u32); } self.instruction(BrTable(targets[..].into(), 0)); @@ -740,19 +914,19 @@ impl Compiler<'_, '_> { // iteration order here places the first case in the innermost block // and the last case in the outermost block. This matches the order // of the jump targets in the `br_table` instruction. - for (src_i, src_case) in src_ty.cases.iter().enumerate() { - let dst_i = dst_ty - .cases - .iter() - .position(|c| c.name == src_case.name) - .unwrap(); - let dst_case = &dst_ty.cases[dst_i]; - let dst_i = u32::try_from(dst_i).unwrap() as i32; + let src_cases_len = u32::try_from(src_cases_len).unwrap(); + for case in src_cases { + let VariantCase { + src_i, + src_ty, + dst_i, + dst_ty, + } = case; // Translate the discriminant here, noting that `dst_i` may be // different than `src_i`. self.push_dst_addr(dst); - self.instruction(I32Const(dst_i)); + self.instruction(I32Const(dst_i as i32)); match dst { Destination::Stack(stack) => self.stack_set(&stack[..1], ValType::I32), Destination::Memory(mem) => match dst_disc_size { @@ -764,9 +938,9 @@ impl Compiler<'_, '_> { // Translate the payload of this case using the various types from // the dst/src. - let src_payload = src.payload_src(self.module, src_disc_size, &src_case.ty); - let dst_payload = dst.payload_dst(self.module, dst_disc_size, &dst_case.ty); - self.translate(&src_case.ty, &src_payload, &dst_case.ty, &dst_payload); + let src_payload = src.payload_src(self.module, src_disc_size, src_ty); + let dst_payload = dst.payload_dst(self.module, dst_disc_size, dst_ty); + self.translate(src_ty, &src_payload, dst_ty, &dst_payload); // If the results of this translation were placed on the stack then // the stack values may need to be padded with more zeros due to @@ -791,9 +965,8 @@ impl Compiler<'_, '_> { // Branch to the outermost block. Note that this isn't needed for // the outermost case since it simply falls through. - let src_len = src_ty.cases.len(); - if src_i != src_len - 1 { - self.instruction(Br((src_len - src_i - 1) as u32)); + if src_i != src_cases_len - 1 { + self.instruction(Br(src_cases_len - src_i - 1)); } self.instruction(End); // end this case's block } @@ -1264,3 +1437,10 @@ impl<'a> Stack<'a> { } } } + +struct VariantCase<'a> { + src_i: u32, + src_ty: &'a InterfaceType, + dst_i: u32, + dst_ty: &'a InterfaceType, +}