Implement other variant-like types in adapter fusion (#4547)

This commit fills out the adapter fusion compiler for the `union`,
`enum`, `option,` and `result` types. The preexisting support for
`variant` types was refactored slightly to be extensible to all of these
other types and they all now call into the same common translation code.
This commit is contained in:
Alex Crichton
2022-07-28 11:47:15 -05:00
committed by GitHub
parent e1148e43be
commit ce7bbef24d
3 changed files with 258 additions and 36 deletions

View File

@@ -16,6 +16,9 @@ use wasmparser::{Validator, WasmFeatures};
use wasmtime_environ::component::*;
use wasmtime_environ::fact::Module;
// Allow inflating to 16 bits but don't go further.
const MAX_ENUM_SIZE: usize = 257;
#[derive(Arbitrary, Debug)]
struct GenAdapterModule {
debug: bool,
@@ -55,6 +58,10 @@ enum ValType {
Record(Vec<ValType>),
Tuple(Vec<ValType>),
Variant(NonZeroLenVec<ValType>),
Union(NonZeroLenVec<ValType>),
Enum(usize),
Option(Box<ValType>),
Expected(Box<ValType>, Box<ValType>),
}
#[derive(Copy, Clone, Arbitrary, Debug)]
@@ -240,6 +247,29 @@ fn intern(types: &mut ComponentTypesBuilder, ty: &ValType) -> InterfaceType {
};
InterfaceType::Variant(types.add_variant_type(ty))
}
ValType::Union(tys) => {
let ty = TypeUnion {
types: tys.0.iter().map(|ty| intern(types, ty)).collect(),
};
InterfaceType::Union(types.add_union_type(ty))
}
ValType::Enum(size) => {
let size = size % MAX_ENUM_SIZE;
let size = if size == 0 { 1 } else { size };
let ty = TypeEnum {
names: (0..size).map(|i| format!("c{i}")).collect(),
};
InterfaceType::Enum(types.add_enum_type(ty))
}
ValType::Option(ty) => {
let ty = intern(types, ty);
InterfaceType::Option(types.add_interface_type(ty))
}
ValType::Expected(ok, err) => {
let ok = intern(types, ok);
let err = intern(types, err);
InterfaceType::Expected(types.add_expected_type(TypeExpected { ok, err }))
}
}
}

View File

@@ -588,7 +588,7 @@ impl ComponentTypesBuilder {
}
wasmparser::ComponentDefinedType::List(e) => {
let ty = self.valtype(e);
InterfaceType::List(self.intern_interface_type(ty))
InterfaceType::List(self.add_interface_type(ty))
}
wasmparser::ComponentDefinedType::Tuple(e) => InterfaceType::Tuple(self.tuple_type(e)),
wasmparser::ComponentDefinedType::Flags(e) => InterfaceType::Flags(self.flags_type(e)),
@@ -596,7 +596,7 @@ impl ComponentTypesBuilder {
wasmparser::ComponentDefinedType::Union(e) => InterfaceType::Union(self.union_type(e)),
wasmparser::ComponentDefinedType::Option(e) => {
let ty = self.valtype(e);
InterfaceType::Option(self.intern_interface_type(ty))
InterfaceType::Option(self.add_interface_type(ty))
}
wasmparser::ComponentDefinedType::Expected { ok, error } => {
InterfaceType::Expected(self.expected_type(ok, error))
@@ -618,14 +618,6 @@ impl ComponentTypesBuilder {
}
}
fn intern_interface_type(&mut self, ty: InterfaceType) -> TypeInterfaceIndex {
intern(
&mut self.interface_types,
&mut self.component_types.interface_types,
ty,
)
}
fn record_type(&mut self, record: &[(&str, wasmparser::ComponentValType)]) -> TypeRecordIndex {
let record = TypeRecord {
fields: record
@@ -675,14 +667,14 @@ impl ComponentTypesBuilder {
let e = TypeEnum {
names: variants.iter().map(|s| s.to_string()).collect(),
};
intern(&mut self.enums, &mut self.component_types.enums, e)
self.add_enum_type(e)
}
fn union_type(&mut self, types: &[wasmparser::ComponentValType]) -> TypeUnionIndex {
let union = TypeUnion {
types: types.iter().map(|ty| self.valtype(ty)).collect(),
};
intern(&mut self.unions, &mut self.component_types.unions, union)
self.add_union_type(union)
}
fn expected_type(
@@ -694,11 +686,7 @@ impl ComponentTypesBuilder {
ok: self.valtype(ok),
err: self.valtype(err),
};
intern(
&mut self.expecteds,
&mut self.component_types.expecteds,
expected,
)
self.add_expected_type(expected)
}
/// Interns a new function type within this type information.
@@ -720,6 +708,30 @@ impl ComponentTypesBuilder {
pub fn add_variant_type(&mut self, ty: TypeVariant) -> TypeVariantIndex {
intern(&mut self.variants, &mut self.component_types.variants, ty)
}
/// Interns a new union type within this type information.
pub fn add_union_type(&mut self, ty: TypeUnion) -> TypeUnionIndex {
intern(&mut self.unions, &mut self.component_types.unions, ty)
}
/// Interns a new enum type within this type information.
pub fn add_enum_type(&mut self, ty: TypeEnum) -> TypeEnumIndex {
intern(&mut self.enums, &mut self.component_types.enums, ty)
}
/// Interns a new expected type within this type information.
pub fn add_expected_type(&mut self, ty: TypeExpected) -> TypeExpectedIndex {
intern(&mut self.expecteds, &mut self.component_types.expecteds, ty)
}
/// Interns a new expected type within this type information.
pub fn add_interface_type(&mut self, ty: InterfaceType) -> TypeInterfaceIndex {
intern(
&mut self.interface_types,
&mut self.component_types.interface_types,
ty,
)
}
}
// Forward the indexing impl to the internal `TypeTables`

View File

@@ -16,8 +16,9 @@
//! can be somewhat arbitrary, an intentional decision.
use crate::component::{
InterfaceType, TypeRecordIndex, TypeTupleIndex, TypeVariantIndex, FLAG_MAY_ENTER,
FLAG_MAY_LEAVE, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS,
InterfaceType, TypeEnumIndex, TypeExpectedIndex, TypeInterfaceIndex, TypeRecordIndex,
TypeTupleIndex, TypeUnionIndex, TypeVariantIndex, FLAG_MAY_ENTER, FLAG_MAY_LEAVE,
MAX_FLAT_PARAMS, MAX_FLAT_RESULTS,
};
use crate::fact::core_types::CoreTypes;
use crate::fact::signature::{align_to, Signature};
@@ -350,6 +351,10 @@ impl Compiler<'_, '_> {
InterfaceType::Record(t) => self.translate_record(*t, src, dst_ty, dst),
InterfaceType::Tuple(t) => self.translate_tuple(*t, src, dst_ty, dst),
InterfaceType::Variant(v) => self.translate_variant(*v, src, dst_ty, dst),
InterfaceType::Union(u) => self.translate_union(*u, src, dst_ty, dst),
InterfaceType::Enum(t) => self.translate_enum(*t, src, dst_ty, dst),
InterfaceType::Option(t) => self.translate_option(*t, src, dst_ty, dst),
InterfaceType::Expected(t) => self.translate_expected(*t, src, dst_ty, dst),
InterfaceType::String => {
// consider this field used for now until this is fully
@@ -686,6 +691,174 @@ impl Compiler<'_, '_> {
let src_disc_size = DiscriminantSize::from_count(src_ty.cases.len()).unwrap();
let dst_disc_size = DiscriminantSize::from_count(dst_ty.cases.len()).unwrap();
let iter = src_ty.cases.iter().enumerate().map(|(src_i, src_case)| {
let dst_i = dst_ty
.cases
.iter()
.position(|c| c.name == src_case.name)
.unwrap();
let dst_case = &dst_ty.cases[dst_i];
let src_i = u32::try_from(src_i).unwrap();
let dst_i = u32::try_from(dst_i).unwrap();
VariantCase {
src_i,
src_ty: &src_case.ty,
dst_i,
dst_ty: &dst_case.ty,
}
});
self.convert_variant(src, src_disc_size, dst, dst_disc_size, iter);
}
fn translate_union(
&mut self,
src_ty: TypeUnionIndex,
src: &Source<'_>,
dst_ty: &InterfaceType,
dst: &Destination,
) {
let src_ty = &self.module.types[src_ty];
let dst_ty = match dst_ty {
InterfaceType::Union(t) => &self.module.types[*t],
_ => panic!("expected an option"),
};
assert_eq!(src_ty.types.len(), dst_ty.types.len());
self.convert_variant(
src,
DiscriminantSize::Size1,
dst,
DiscriminantSize::Size1,
src_ty
.types
.iter()
.zip(dst_ty.types.iter())
.enumerate()
.map(|(i, (src_ty, dst_ty))| {
let i = u32::try_from(i).unwrap();
VariantCase {
src_i: i,
dst_i: i,
src_ty,
dst_ty,
}
}),
);
}
fn translate_enum(
&mut self,
src_ty: TypeEnumIndex,
src: &Source<'_>,
dst_ty: &InterfaceType,
dst: &Destination,
) {
let src_ty = &self.module.types[src_ty];
let dst_ty = match dst_ty {
InterfaceType::Enum(t) => &self.module.types[*t],
_ => panic!("expected an option"),
};
let unit = &InterfaceType::Unit;
self.convert_variant(
src,
DiscriminantSize::from_count(src_ty.names.len()).unwrap(),
dst,
DiscriminantSize::from_count(dst_ty.names.len()).unwrap(),
src_ty.names.iter().enumerate().map(|(src_i, src_name)| {
let dst_i = dst_ty.names.iter().position(|n| n == src_name).unwrap();
let src_i = u32::try_from(src_i).unwrap();
let dst_i = u32::try_from(dst_i).unwrap();
VariantCase {
src_i,
dst_i,
src_ty: unit,
dst_ty: unit,
}
}),
);
}
fn translate_option(
&mut self,
src_ty: TypeInterfaceIndex,
src: &Source<'_>,
dst_ty: &InterfaceType,
dst: &Destination,
) {
let src_ty = &self.module.types[src_ty];
let dst_ty = match dst_ty {
InterfaceType::Option(t) => &self.module.types[*t],
_ => panic!("expected an option"),
};
self.convert_variant(
src,
DiscriminantSize::Size1,
dst,
DiscriminantSize::Size1,
[
VariantCase {
src_i: 0,
dst_i: 0,
src_ty: &InterfaceType::Unit,
dst_ty: &InterfaceType::Unit,
},
VariantCase {
src_i: 1,
dst_i: 1,
src_ty,
dst_ty,
},
]
.into_iter(),
);
}
fn translate_expected(
&mut self,
src_ty: TypeExpectedIndex,
src: &Source<'_>,
dst_ty: &InterfaceType,
dst: &Destination,
) {
let src_ty = &self.module.types[src_ty];
let dst_ty = match dst_ty {
InterfaceType::Expected(t) => &self.module.types[*t],
_ => panic!("expected an expected"),
};
self.convert_variant(
src,
DiscriminantSize::Size1,
dst,
DiscriminantSize::Size1,
[
VariantCase {
src_i: 0,
dst_i: 0,
src_ty: &src_ty.ok,
dst_ty: &dst_ty.ok,
},
VariantCase {
src_i: 1,
dst_i: 1,
src_ty: &src_ty.err,
dst_ty: &dst_ty.err,
},
]
.into_iter(),
);
}
fn convert_variant<'a>(
&mut self,
src: &Source<'_>,
src_disc_size: DiscriminantSize,
dst: &Destination,
dst_disc_size: DiscriminantSize,
src_cases: impl ExactSizeIterator<Item = VariantCase<'a>>,
) {
// The outermost block is special since it has the result type of the
// translation here. That will depend on the `dst`.
let outer_block_ty = match dst {
@@ -703,7 +876,8 @@ impl Compiler<'_, '_> {
// After the outermost block generate a new block for each of the
// remaining cases.
for _ in 0..src_ty.cases.len() - 1 {
let src_cases_len = src_cases.len();
for _ in 0..src_cases_len - 1 {
self.instruction(Block(BlockType::Empty));
}
@@ -727,7 +901,7 @@ impl Compiler<'_, '_> {
// Generate the `br_table` for the discriminant. Each case has an
// offset of 1 to skip the trapping block.
let mut targets = Vec::new();
for i in 0..src_ty.cases.len() {
for i in 0..src_cases_len {
targets.push((i + 1) as u32);
}
self.instruction(BrTable(targets[..].into(), 0));
@@ -740,19 +914,19 @@ impl Compiler<'_, '_> {
// iteration order here places the first case in the innermost block
// and the last case in the outermost block. This matches the order
// of the jump targets in the `br_table` instruction.
for (src_i, src_case) in src_ty.cases.iter().enumerate() {
let dst_i = dst_ty
.cases
.iter()
.position(|c| c.name == src_case.name)
.unwrap();
let dst_case = &dst_ty.cases[dst_i];
let dst_i = u32::try_from(dst_i).unwrap() as i32;
let src_cases_len = u32::try_from(src_cases_len).unwrap();
for case in src_cases {
let VariantCase {
src_i,
src_ty,
dst_i,
dst_ty,
} = case;
// Translate the discriminant here, noting that `dst_i` may be
// different than `src_i`.
self.push_dst_addr(dst);
self.instruction(I32Const(dst_i));
self.instruction(I32Const(dst_i as i32));
match dst {
Destination::Stack(stack) => self.stack_set(&stack[..1], ValType::I32),
Destination::Memory(mem) => match dst_disc_size {
@@ -764,9 +938,9 @@ impl Compiler<'_, '_> {
// Translate the payload of this case using the various types from
// the dst/src.
let src_payload = src.payload_src(self.module, src_disc_size, &src_case.ty);
let dst_payload = dst.payload_dst(self.module, dst_disc_size, &dst_case.ty);
self.translate(&src_case.ty, &src_payload, &dst_case.ty, &dst_payload);
let src_payload = src.payload_src(self.module, src_disc_size, src_ty);
let dst_payload = dst.payload_dst(self.module, dst_disc_size, dst_ty);
self.translate(src_ty, &src_payload, dst_ty, &dst_payload);
// If the results of this translation were placed on the stack then
// the stack values may need to be padded with more zeros due to
@@ -791,9 +965,8 @@ impl Compiler<'_, '_> {
// Branch to the outermost block. Note that this isn't needed for
// the outermost case since it simply falls through.
let src_len = src_ty.cases.len();
if src_i != src_len - 1 {
self.instruction(Br((src_len - src_i - 1) as u32));
if src_i != src_cases_len - 1 {
self.instruction(Br(src_cases_len - src_i - 1));
}
self.instruction(End); // end this case's block
}
@@ -1264,3 +1437,10 @@ impl<'a> Stack<'a> {
}
}
}
struct VariantCase<'a> {
src_i: u32,
src_ty: &'a InterfaceType,
dst_i: u32,
dst_ty: &'a InterfaceType,
}