Implement other variant-like types in adapter fusion (#4547)

This commit fills out the adapter fusion compiler for the `union`,
`enum`, `option,` and `result` types. The preexisting support for
`variant` types was refactored slightly to be extensible to all of these
other types and they all now call into the same common translation code.
This commit is contained in:
Alex Crichton
2022-07-28 11:47:15 -05:00
committed by GitHub
parent e1148e43be
commit ce7bbef24d
3 changed files with 258 additions and 36 deletions

View File

@@ -16,6 +16,9 @@ use wasmparser::{Validator, WasmFeatures};
use wasmtime_environ::component::*; use wasmtime_environ::component::*;
use wasmtime_environ::fact::Module; use wasmtime_environ::fact::Module;
// Allow inflating to 16 bits but don't go further.
const MAX_ENUM_SIZE: usize = 257;
#[derive(Arbitrary, Debug)] #[derive(Arbitrary, Debug)]
struct GenAdapterModule { struct GenAdapterModule {
debug: bool, debug: bool,
@@ -55,6 +58,10 @@ enum ValType {
Record(Vec<ValType>), Record(Vec<ValType>),
Tuple(Vec<ValType>), Tuple(Vec<ValType>),
Variant(NonZeroLenVec<ValType>), Variant(NonZeroLenVec<ValType>),
Union(NonZeroLenVec<ValType>),
Enum(usize),
Option(Box<ValType>),
Expected(Box<ValType>, Box<ValType>),
} }
#[derive(Copy, Clone, Arbitrary, Debug)] #[derive(Copy, Clone, Arbitrary, Debug)]
@@ -240,6 +247,29 @@ fn intern(types: &mut ComponentTypesBuilder, ty: &ValType) -> InterfaceType {
}; };
InterfaceType::Variant(types.add_variant_type(ty)) InterfaceType::Variant(types.add_variant_type(ty))
} }
ValType::Union(tys) => {
let ty = TypeUnion {
types: tys.0.iter().map(|ty| intern(types, ty)).collect(),
};
InterfaceType::Union(types.add_union_type(ty))
}
ValType::Enum(size) => {
let size = size % MAX_ENUM_SIZE;
let size = if size == 0 { 1 } else { size };
let ty = TypeEnum {
names: (0..size).map(|i| format!("c{i}")).collect(),
};
InterfaceType::Enum(types.add_enum_type(ty))
}
ValType::Option(ty) => {
let ty = intern(types, ty);
InterfaceType::Option(types.add_interface_type(ty))
}
ValType::Expected(ok, err) => {
let ok = intern(types, ok);
let err = intern(types, err);
InterfaceType::Expected(types.add_expected_type(TypeExpected { ok, err }))
}
} }
} }

View File

@@ -588,7 +588,7 @@ impl ComponentTypesBuilder {
} }
wasmparser::ComponentDefinedType::List(e) => { wasmparser::ComponentDefinedType::List(e) => {
let ty = self.valtype(e); let ty = self.valtype(e);
InterfaceType::List(self.intern_interface_type(ty)) InterfaceType::List(self.add_interface_type(ty))
} }
wasmparser::ComponentDefinedType::Tuple(e) => InterfaceType::Tuple(self.tuple_type(e)), wasmparser::ComponentDefinedType::Tuple(e) => InterfaceType::Tuple(self.tuple_type(e)),
wasmparser::ComponentDefinedType::Flags(e) => InterfaceType::Flags(self.flags_type(e)), wasmparser::ComponentDefinedType::Flags(e) => InterfaceType::Flags(self.flags_type(e)),
@@ -596,7 +596,7 @@ impl ComponentTypesBuilder {
wasmparser::ComponentDefinedType::Union(e) => InterfaceType::Union(self.union_type(e)), wasmparser::ComponentDefinedType::Union(e) => InterfaceType::Union(self.union_type(e)),
wasmparser::ComponentDefinedType::Option(e) => { wasmparser::ComponentDefinedType::Option(e) => {
let ty = self.valtype(e); let ty = self.valtype(e);
InterfaceType::Option(self.intern_interface_type(ty)) InterfaceType::Option(self.add_interface_type(ty))
} }
wasmparser::ComponentDefinedType::Expected { ok, error } => { wasmparser::ComponentDefinedType::Expected { ok, error } => {
InterfaceType::Expected(self.expected_type(ok, error)) InterfaceType::Expected(self.expected_type(ok, error))
@@ -618,14 +618,6 @@ impl ComponentTypesBuilder {
} }
} }
fn intern_interface_type(&mut self, ty: InterfaceType) -> TypeInterfaceIndex {
intern(
&mut self.interface_types,
&mut self.component_types.interface_types,
ty,
)
}
fn record_type(&mut self, record: &[(&str, wasmparser::ComponentValType)]) -> TypeRecordIndex { fn record_type(&mut self, record: &[(&str, wasmparser::ComponentValType)]) -> TypeRecordIndex {
let record = TypeRecord { let record = TypeRecord {
fields: record fields: record
@@ -675,14 +667,14 @@ impl ComponentTypesBuilder {
let e = TypeEnum { let e = TypeEnum {
names: variants.iter().map(|s| s.to_string()).collect(), names: variants.iter().map(|s| s.to_string()).collect(),
}; };
intern(&mut self.enums, &mut self.component_types.enums, e) self.add_enum_type(e)
} }
fn union_type(&mut self, types: &[wasmparser::ComponentValType]) -> TypeUnionIndex { fn union_type(&mut self, types: &[wasmparser::ComponentValType]) -> TypeUnionIndex {
let union = TypeUnion { let union = TypeUnion {
types: types.iter().map(|ty| self.valtype(ty)).collect(), types: types.iter().map(|ty| self.valtype(ty)).collect(),
}; };
intern(&mut self.unions, &mut self.component_types.unions, union) self.add_union_type(union)
} }
fn expected_type( fn expected_type(
@@ -694,11 +686,7 @@ impl ComponentTypesBuilder {
ok: self.valtype(ok), ok: self.valtype(ok),
err: self.valtype(err), err: self.valtype(err),
}; };
intern( self.add_expected_type(expected)
&mut self.expecteds,
&mut self.component_types.expecteds,
expected,
)
} }
/// Interns a new function type within this type information. /// Interns a new function type within this type information.
@@ -720,6 +708,30 @@ impl ComponentTypesBuilder {
pub fn add_variant_type(&mut self, ty: TypeVariant) -> TypeVariantIndex { pub fn add_variant_type(&mut self, ty: TypeVariant) -> TypeVariantIndex {
intern(&mut self.variants, &mut self.component_types.variants, ty) intern(&mut self.variants, &mut self.component_types.variants, ty)
} }
/// Interns a new union type within this type information.
pub fn add_union_type(&mut self, ty: TypeUnion) -> TypeUnionIndex {
intern(&mut self.unions, &mut self.component_types.unions, ty)
}
/// Interns a new enum type within this type information.
pub fn add_enum_type(&mut self, ty: TypeEnum) -> TypeEnumIndex {
intern(&mut self.enums, &mut self.component_types.enums, ty)
}
/// Interns a new expected type within this type information.
pub fn add_expected_type(&mut self, ty: TypeExpected) -> TypeExpectedIndex {
intern(&mut self.expecteds, &mut self.component_types.expecteds, ty)
}
/// Interns a new expected type within this type information.
pub fn add_interface_type(&mut self, ty: InterfaceType) -> TypeInterfaceIndex {
intern(
&mut self.interface_types,
&mut self.component_types.interface_types,
ty,
)
}
} }
// Forward the indexing impl to the internal `TypeTables` // Forward the indexing impl to the internal `TypeTables`

View File

@@ -16,8 +16,9 @@
//! can be somewhat arbitrary, an intentional decision. //! can be somewhat arbitrary, an intentional decision.
use crate::component::{ use crate::component::{
InterfaceType, TypeRecordIndex, TypeTupleIndex, TypeVariantIndex, FLAG_MAY_ENTER, InterfaceType, TypeEnumIndex, TypeExpectedIndex, TypeInterfaceIndex, TypeRecordIndex,
FLAG_MAY_LEAVE, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, TypeTupleIndex, TypeUnionIndex, TypeVariantIndex, FLAG_MAY_ENTER, FLAG_MAY_LEAVE,
MAX_FLAT_PARAMS, MAX_FLAT_RESULTS,
}; };
use crate::fact::core_types::CoreTypes; use crate::fact::core_types::CoreTypes;
use crate::fact::signature::{align_to, Signature}; use crate::fact::signature::{align_to, Signature};
@@ -350,6 +351,10 @@ impl Compiler<'_, '_> {
InterfaceType::Record(t) => self.translate_record(*t, src, dst_ty, dst), InterfaceType::Record(t) => self.translate_record(*t, src, dst_ty, dst),
InterfaceType::Tuple(t) => self.translate_tuple(*t, src, dst_ty, dst), InterfaceType::Tuple(t) => self.translate_tuple(*t, src, dst_ty, dst),
InterfaceType::Variant(v) => self.translate_variant(*v, src, dst_ty, dst), InterfaceType::Variant(v) => self.translate_variant(*v, src, dst_ty, dst),
InterfaceType::Union(u) => self.translate_union(*u, src, dst_ty, dst),
InterfaceType::Enum(t) => self.translate_enum(*t, src, dst_ty, dst),
InterfaceType::Option(t) => self.translate_option(*t, src, dst_ty, dst),
InterfaceType::Expected(t) => self.translate_expected(*t, src, dst_ty, dst),
InterfaceType::String => { InterfaceType::String => {
// consider this field used for now until this is fully // consider this field used for now until this is fully
@@ -686,6 +691,174 @@ impl Compiler<'_, '_> {
let src_disc_size = DiscriminantSize::from_count(src_ty.cases.len()).unwrap(); let src_disc_size = DiscriminantSize::from_count(src_ty.cases.len()).unwrap();
let dst_disc_size = DiscriminantSize::from_count(dst_ty.cases.len()).unwrap(); let dst_disc_size = DiscriminantSize::from_count(dst_ty.cases.len()).unwrap();
let iter = src_ty.cases.iter().enumerate().map(|(src_i, src_case)| {
let dst_i = dst_ty
.cases
.iter()
.position(|c| c.name == src_case.name)
.unwrap();
let dst_case = &dst_ty.cases[dst_i];
let src_i = u32::try_from(src_i).unwrap();
let dst_i = u32::try_from(dst_i).unwrap();
VariantCase {
src_i,
src_ty: &src_case.ty,
dst_i,
dst_ty: &dst_case.ty,
}
});
self.convert_variant(src, src_disc_size, dst, dst_disc_size, iter);
}
fn translate_union(
&mut self,
src_ty: TypeUnionIndex,
src: &Source<'_>,
dst_ty: &InterfaceType,
dst: &Destination,
) {
let src_ty = &self.module.types[src_ty];
let dst_ty = match dst_ty {
InterfaceType::Union(t) => &self.module.types[*t],
_ => panic!("expected an option"),
};
assert_eq!(src_ty.types.len(), dst_ty.types.len());
self.convert_variant(
src,
DiscriminantSize::Size1,
dst,
DiscriminantSize::Size1,
src_ty
.types
.iter()
.zip(dst_ty.types.iter())
.enumerate()
.map(|(i, (src_ty, dst_ty))| {
let i = u32::try_from(i).unwrap();
VariantCase {
src_i: i,
dst_i: i,
src_ty,
dst_ty,
}
}),
);
}
fn translate_enum(
&mut self,
src_ty: TypeEnumIndex,
src: &Source<'_>,
dst_ty: &InterfaceType,
dst: &Destination,
) {
let src_ty = &self.module.types[src_ty];
let dst_ty = match dst_ty {
InterfaceType::Enum(t) => &self.module.types[*t],
_ => panic!("expected an option"),
};
let unit = &InterfaceType::Unit;
self.convert_variant(
src,
DiscriminantSize::from_count(src_ty.names.len()).unwrap(),
dst,
DiscriminantSize::from_count(dst_ty.names.len()).unwrap(),
src_ty.names.iter().enumerate().map(|(src_i, src_name)| {
let dst_i = dst_ty.names.iter().position(|n| n == src_name).unwrap();
let src_i = u32::try_from(src_i).unwrap();
let dst_i = u32::try_from(dst_i).unwrap();
VariantCase {
src_i,
dst_i,
src_ty: unit,
dst_ty: unit,
}
}),
);
}
fn translate_option(
&mut self,
src_ty: TypeInterfaceIndex,
src: &Source<'_>,
dst_ty: &InterfaceType,
dst: &Destination,
) {
let src_ty = &self.module.types[src_ty];
let dst_ty = match dst_ty {
InterfaceType::Option(t) => &self.module.types[*t],
_ => panic!("expected an option"),
};
self.convert_variant(
src,
DiscriminantSize::Size1,
dst,
DiscriminantSize::Size1,
[
VariantCase {
src_i: 0,
dst_i: 0,
src_ty: &InterfaceType::Unit,
dst_ty: &InterfaceType::Unit,
},
VariantCase {
src_i: 1,
dst_i: 1,
src_ty,
dst_ty,
},
]
.into_iter(),
);
}
fn translate_expected(
&mut self,
src_ty: TypeExpectedIndex,
src: &Source<'_>,
dst_ty: &InterfaceType,
dst: &Destination,
) {
let src_ty = &self.module.types[src_ty];
let dst_ty = match dst_ty {
InterfaceType::Expected(t) => &self.module.types[*t],
_ => panic!("expected an expected"),
};
self.convert_variant(
src,
DiscriminantSize::Size1,
dst,
DiscriminantSize::Size1,
[
VariantCase {
src_i: 0,
dst_i: 0,
src_ty: &src_ty.ok,
dst_ty: &dst_ty.ok,
},
VariantCase {
src_i: 1,
dst_i: 1,
src_ty: &src_ty.err,
dst_ty: &dst_ty.err,
},
]
.into_iter(),
);
}
fn convert_variant<'a>(
&mut self,
src: &Source<'_>,
src_disc_size: DiscriminantSize,
dst: &Destination,
dst_disc_size: DiscriminantSize,
src_cases: impl ExactSizeIterator<Item = VariantCase<'a>>,
) {
// The outermost block is special since it has the result type of the // The outermost block is special since it has the result type of the
// translation here. That will depend on the `dst`. // translation here. That will depend on the `dst`.
let outer_block_ty = match dst { let outer_block_ty = match dst {
@@ -703,7 +876,8 @@ impl Compiler<'_, '_> {
// After the outermost block generate a new block for each of the // After the outermost block generate a new block for each of the
// remaining cases. // remaining cases.
for _ in 0..src_ty.cases.len() - 1 { let src_cases_len = src_cases.len();
for _ in 0..src_cases_len - 1 {
self.instruction(Block(BlockType::Empty)); self.instruction(Block(BlockType::Empty));
} }
@@ -727,7 +901,7 @@ impl Compiler<'_, '_> {
// Generate the `br_table` for the discriminant. Each case has an // Generate the `br_table` for the discriminant. Each case has an
// offset of 1 to skip the trapping block. // offset of 1 to skip the trapping block.
let mut targets = Vec::new(); let mut targets = Vec::new();
for i in 0..src_ty.cases.len() { for i in 0..src_cases_len {
targets.push((i + 1) as u32); targets.push((i + 1) as u32);
} }
self.instruction(BrTable(targets[..].into(), 0)); self.instruction(BrTable(targets[..].into(), 0));
@@ -740,19 +914,19 @@ impl Compiler<'_, '_> {
// iteration order here places the first case in the innermost block // iteration order here places the first case in the innermost block
// and the last case in the outermost block. This matches the order // and the last case in the outermost block. This matches the order
// of the jump targets in the `br_table` instruction. // of the jump targets in the `br_table` instruction.
for (src_i, src_case) in src_ty.cases.iter().enumerate() { let src_cases_len = u32::try_from(src_cases_len).unwrap();
let dst_i = dst_ty for case in src_cases {
.cases let VariantCase {
.iter() src_i,
.position(|c| c.name == src_case.name) src_ty,
.unwrap(); dst_i,
let dst_case = &dst_ty.cases[dst_i]; dst_ty,
let dst_i = u32::try_from(dst_i).unwrap() as i32; } = case;
// Translate the discriminant here, noting that `dst_i` may be // Translate the discriminant here, noting that `dst_i` may be
// different than `src_i`. // different than `src_i`.
self.push_dst_addr(dst); self.push_dst_addr(dst);
self.instruction(I32Const(dst_i)); self.instruction(I32Const(dst_i as i32));
match dst { match dst {
Destination::Stack(stack) => self.stack_set(&stack[..1], ValType::I32), Destination::Stack(stack) => self.stack_set(&stack[..1], ValType::I32),
Destination::Memory(mem) => match dst_disc_size { Destination::Memory(mem) => match dst_disc_size {
@@ -764,9 +938,9 @@ impl Compiler<'_, '_> {
// Translate the payload of this case using the various types from // Translate the payload of this case using the various types from
// the dst/src. // the dst/src.
let src_payload = src.payload_src(self.module, src_disc_size, &src_case.ty); let src_payload = src.payload_src(self.module, src_disc_size, src_ty);
let dst_payload = dst.payload_dst(self.module, dst_disc_size, &dst_case.ty); let dst_payload = dst.payload_dst(self.module, dst_disc_size, dst_ty);
self.translate(&src_case.ty, &src_payload, &dst_case.ty, &dst_payload); self.translate(src_ty, &src_payload, dst_ty, &dst_payload);
// If the results of this translation were placed on the stack then // If the results of this translation were placed on the stack then
// the stack values may need to be padded with more zeros due to // the stack values may need to be padded with more zeros due to
@@ -791,9 +965,8 @@ impl Compiler<'_, '_> {
// Branch to the outermost block. Note that this isn't needed for // Branch to the outermost block. Note that this isn't needed for
// the outermost case since it simply falls through. // the outermost case since it simply falls through.
let src_len = src_ty.cases.len(); if src_i != src_cases_len - 1 {
if src_i != src_len - 1 { self.instruction(Br(src_cases_len - src_i - 1));
self.instruction(Br((src_len - src_i - 1) as u32));
} }
self.instruction(End); // end this case's block self.instruction(End); // end this case's block
} }
@@ -1264,3 +1437,10 @@ impl<'a> Stack<'a> {
} }
} }
} }
struct VariantCase<'a> {
src_i: u32,
src_ty: &'a InterfaceType,
dst_i: u32,
dst_ty: &'a InterfaceType,
}