Implement other variant-like types in adapter fusion (#4547)
This commit fills out the adapter fusion compiler for the `union`, `enum`, `option,` and `result` types. The preexisting support for `variant` types was refactored slightly to be extensible to all of these other types and they all now call into the same common translation code.
This commit is contained in:
@@ -16,6 +16,9 @@ use wasmparser::{Validator, WasmFeatures};
|
||||
use wasmtime_environ::component::*;
|
||||
use wasmtime_environ::fact::Module;
|
||||
|
||||
// Allow inflating to 16 bits but don't go further.
|
||||
const MAX_ENUM_SIZE: usize = 257;
|
||||
|
||||
#[derive(Arbitrary, Debug)]
|
||||
struct GenAdapterModule {
|
||||
debug: bool,
|
||||
@@ -55,6 +58,10 @@ enum ValType {
|
||||
Record(Vec<ValType>),
|
||||
Tuple(Vec<ValType>),
|
||||
Variant(NonZeroLenVec<ValType>),
|
||||
Union(NonZeroLenVec<ValType>),
|
||||
Enum(usize),
|
||||
Option(Box<ValType>),
|
||||
Expected(Box<ValType>, Box<ValType>),
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Arbitrary, Debug)]
|
||||
@@ -240,6 +247,29 @@ fn intern(types: &mut ComponentTypesBuilder, ty: &ValType) -> InterfaceType {
|
||||
};
|
||||
InterfaceType::Variant(types.add_variant_type(ty))
|
||||
}
|
||||
ValType::Union(tys) => {
|
||||
let ty = TypeUnion {
|
||||
types: tys.0.iter().map(|ty| intern(types, ty)).collect(),
|
||||
};
|
||||
InterfaceType::Union(types.add_union_type(ty))
|
||||
}
|
||||
ValType::Enum(size) => {
|
||||
let size = size % MAX_ENUM_SIZE;
|
||||
let size = if size == 0 { 1 } else { size };
|
||||
let ty = TypeEnum {
|
||||
names: (0..size).map(|i| format!("c{i}")).collect(),
|
||||
};
|
||||
InterfaceType::Enum(types.add_enum_type(ty))
|
||||
}
|
||||
ValType::Option(ty) => {
|
||||
let ty = intern(types, ty);
|
||||
InterfaceType::Option(types.add_interface_type(ty))
|
||||
}
|
||||
ValType::Expected(ok, err) => {
|
||||
let ok = intern(types, ok);
|
||||
let err = intern(types, err);
|
||||
InterfaceType::Expected(types.add_expected_type(TypeExpected { ok, err }))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -588,7 +588,7 @@ impl ComponentTypesBuilder {
|
||||
}
|
||||
wasmparser::ComponentDefinedType::List(e) => {
|
||||
let ty = self.valtype(e);
|
||||
InterfaceType::List(self.intern_interface_type(ty))
|
||||
InterfaceType::List(self.add_interface_type(ty))
|
||||
}
|
||||
wasmparser::ComponentDefinedType::Tuple(e) => InterfaceType::Tuple(self.tuple_type(e)),
|
||||
wasmparser::ComponentDefinedType::Flags(e) => InterfaceType::Flags(self.flags_type(e)),
|
||||
@@ -596,7 +596,7 @@ impl ComponentTypesBuilder {
|
||||
wasmparser::ComponentDefinedType::Union(e) => InterfaceType::Union(self.union_type(e)),
|
||||
wasmparser::ComponentDefinedType::Option(e) => {
|
||||
let ty = self.valtype(e);
|
||||
InterfaceType::Option(self.intern_interface_type(ty))
|
||||
InterfaceType::Option(self.add_interface_type(ty))
|
||||
}
|
||||
wasmparser::ComponentDefinedType::Expected { ok, error } => {
|
||||
InterfaceType::Expected(self.expected_type(ok, error))
|
||||
@@ -618,14 +618,6 @@ impl ComponentTypesBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
fn intern_interface_type(&mut self, ty: InterfaceType) -> TypeInterfaceIndex {
|
||||
intern(
|
||||
&mut self.interface_types,
|
||||
&mut self.component_types.interface_types,
|
||||
ty,
|
||||
)
|
||||
}
|
||||
|
||||
fn record_type(&mut self, record: &[(&str, wasmparser::ComponentValType)]) -> TypeRecordIndex {
|
||||
let record = TypeRecord {
|
||||
fields: record
|
||||
@@ -675,14 +667,14 @@ impl ComponentTypesBuilder {
|
||||
let e = TypeEnum {
|
||||
names: variants.iter().map(|s| s.to_string()).collect(),
|
||||
};
|
||||
intern(&mut self.enums, &mut self.component_types.enums, e)
|
||||
self.add_enum_type(e)
|
||||
}
|
||||
|
||||
fn union_type(&mut self, types: &[wasmparser::ComponentValType]) -> TypeUnionIndex {
|
||||
let union = TypeUnion {
|
||||
types: types.iter().map(|ty| self.valtype(ty)).collect(),
|
||||
};
|
||||
intern(&mut self.unions, &mut self.component_types.unions, union)
|
||||
self.add_union_type(union)
|
||||
}
|
||||
|
||||
fn expected_type(
|
||||
@@ -694,11 +686,7 @@ impl ComponentTypesBuilder {
|
||||
ok: self.valtype(ok),
|
||||
err: self.valtype(err),
|
||||
};
|
||||
intern(
|
||||
&mut self.expecteds,
|
||||
&mut self.component_types.expecteds,
|
||||
expected,
|
||||
)
|
||||
self.add_expected_type(expected)
|
||||
}
|
||||
|
||||
/// Interns a new function type within this type information.
|
||||
@@ -720,6 +708,30 @@ impl ComponentTypesBuilder {
|
||||
pub fn add_variant_type(&mut self, ty: TypeVariant) -> TypeVariantIndex {
|
||||
intern(&mut self.variants, &mut self.component_types.variants, ty)
|
||||
}
|
||||
|
||||
/// Interns a new union type within this type information.
|
||||
pub fn add_union_type(&mut self, ty: TypeUnion) -> TypeUnionIndex {
|
||||
intern(&mut self.unions, &mut self.component_types.unions, ty)
|
||||
}
|
||||
|
||||
/// Interns a new enum type within this type information.
|
||||
pub fn add_enum_type(&mut self, ty: TypeEnum) -> TypeEnumIndex {
|
||||
intern(&mut self.enums, &mut self.component_types.enums, ty)
|
||||
}
|
||||
|
||||
/// Interns a new expected type within this type information.
|
||||
pub fn add_expected_type(&mut self, ty: TypeExpected) -> TypeExpectedIndex {
|
||||
intern(&mut self.expecteds, &mut self.component_types.expecteds, ty)
|
||||
}
|
||||
|
||||
/// Interns a new expected type within this type information.
|
||||
pub fn add_interface_type(&mut self, ty: InterfaceType) -> TypeInterfaceIndex {
|
||||
intern(
|
||||
&mut self.interface_types,
|
||||
&mut self.component_types.interface_types,
|
||||
ty,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Forward the indexing impl to the internal `TypeTables`
|
||||
|
||||
@@ -16,8 +16,9 @@
|
||||
//! can be somewhat arbitrary, an intentional decision.
|
||||
|
||||
use crate::component::{
|
||||
InterfaceType, TypeRecordIndex, TypeTupleIndex, TypeVariantIndex, FLAG_MAY_ENTER,
|
||||
FLAG_MAY_LEAVE, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS,
|
||||
InterfaceType, TypeEnumIndex, TypeExpectedIndex, TypeInterfaceIndex, TypeRecordIndex,
|
||||
TypeTupleIndex, TypeUnionIndex, TypeVariantIndex, FLAG_MAY_ENTER, FLAG_MAY_LEAVE,
|
||||
MAX_FLAT_PARAMS, MAX_FLAT_RESULTS,
|
||||
};
|
||||
use crate::fact::core_types::CoreTypes;
|
||||
use crate::fact::signature::{align_to, Signature};
|
||||
@@ -350,6 +351,10 @@ impl Compiler<'_, '_> {
|
||||
InterfaceType::Record(t) => self.translate_record(*t, src, dst_ty, dst),
|
||||
InterfaceType::Tuple(t) => self.translate_tuple(*t, src, dst_ty, dst),
|
||||
InterfaceType::Variant(v) => self.translate_variant(*v, src, dst_ty, dst),
|
||||
InterfaceType::Union(u) => self.translate_union(*u, src, dst_ty, dst),
|
||||
InterfaceType::Enum(t) => self.translate_enum(*t, src, dst_ty, dst),
|
||||
InterfaceType::Option(t) => self.translate_option(*t, src, dst_ty, dst),
|
||||
InterfaceType::Expected(t) => self.translate_expected(*t, src, dst_ty, dst),
|
||||
|
||||
InterfaceType::String => {
|
||||
// consider this field used for now until this is fully
|
||||
@@ -686,6 +691,174 @@ impl Compiler<'_, '_> {
|
||||
let src_disc_size = DiscriminantSize::from_count(src_ty.cases.len()).unwrap();
|
||||
let dst_disc_size = DiscriminantSize::from_count(dst_ty.cases.len()).unwrap();
|
||||
|
||||
let iter = src_ty.cases.iter().enumerate().map(|(src_i, src_case)| {
|
||||
let dst_i = dst_ty
|
||||
.cases
|
||||
.iter()
|
||||
.position(|c| c.name == src_case.name)
|
||||
.unwrap();
|
||||
let dst_case = &dst_ty.cases[dst_i];
|
||||
let src_i = u32::try_from(src_i).unwrap();
|
||||
let dst_i = u32::try_from(dst_i).unwrap();
|
||||
VariantCase {
|
||||
src_i,
|
||||
src_ty: &src_case.ty,
|
||||
dst_i,
|
||||
dst_ty: &dst_case.ty,
|
||||
}
|
||||
});
|
||||
self.convert_variant(src, src_disc_size, dst, dst_disc_size, iter);
|
||||
}
|
||||
|
||||
fn translate_union(
|
||||
&mut self,
|
||||
src_ty: TypeUnionIndex,
|
||||
src: &Source<'_>,
|
||||
dst_ty: &InterfaceType,
|
||||
dst: &Destination,
|
||||
) {
|
||||
let src_ty = &self.module.types[src_ty];
|
||||
let dst_ty = match dst_ty {
|
||||
InterfaceType::Union(t) => &self.module.types[*t],
|
||||
_ => panic!("expected an option"),
|
||||
};
|
||||
assert_eq!(src_ty.types.len(), dst_ty.types.len());
|
||||
|
||||
self.convert_variant(
|
||||
src,
|
||||
DiscriminantSize::Size1,
|
||||
dst,
|
||||
DiscriminantSize::Size1,
|
||||
src_ty
|
||||
.types
|
||||
.iter()
|
||||
.zip(dst_ty.types.iter())
|
||||
.enumerate()
|
||||
.map(|(i, (src_ty, dst_ty))| {
|
||||
let i = u32::try_from(i).unwrap();
|
||||
VariantCase {
|
||||
src_i: i,
|
||||
dst_i: i,
|
||||
src_ty,
|
||||
dst_ty,
|
||||
}
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
fn translate_enum(
|
||||
&mut self,
|
||||
src_ty: TypeEnumIndex,
|
||||
src: &Source<'_>,
|
||||
dst_ty: &InterfaceType,
|
||||
dst: &Destination,
|
||||
) {
|
||||
let src_ty = &self.module.types[src_ty];
|
||||
let dst_ty = match dst_ty {
|
||||
InterfaceType::Enum(t) => &self.module.types[*t],
|
||||
_ => panic!("expected an option"),
|
||||
};
|
||||
|
||||
let unit = &InterfaceType::Unit;
|
||||
self.convert_variant(
|
||||
src,
|
||||
DiscriminantSize::from_count(src_ty.names.len()).unwrap(),
|
||||
dst,
|
||||
DiscriminantSize::from_count(dst_ty.names.len()).unwrap(),
|
||||
src_ty.names.iter().enumerate().map(|(src_i, src_name)| {
|
||||
let dst_i = dst_ty.names.iter().position(|n| n == src_name).unwrap();
|
||||
let src_i = u32::try_from(src_i).unwrap();
|
||||
let dst_i = u32::try_from(dst_i).unwrap();
|
||||
VariantCase {
|
||||
src_i,
|
||||
dst_i,
|
||||
src_ty: unit,
|
||||
dst_ty: unit,
|
||||
}
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
fn translate_option(
|
||||
&mut self,
|
||||
src_ty: TypeInterfaceIndex,
|
||||
src: &Source<'_>,
|
||||
dst_ty: &InterfaceType,
|
||||
dst: &Destination,
|
||||
) {
|
||||
let src_ty = &self.module.types[src_ty];
|
||||
let dst_ty = match dst_ty {
|
||||
InterfaceType::Option(t) => &self.module.types[*t],
|
||||
_ => panic!("expected an option"),
|
||||
};
|
||||
|
||||
self.convert_variant(
|
||||
src,
|
||||
DiscriminantSize::Size1,
|
||||
dst,
|
||||
DiscriminantSize::Size1,
|
||||
[
|
||||
VariantCase {
|
||||
src_i: 0,
|
||||
dst_i: 0,
|
||||
src_ty: &InterfaceType::Unit,
|
||||
dst_ty: &InterfaceType::Unit,
|
||||
},
|
||||
VariantCase {
|
||||
src_i: 1,
|
||||
dst_i: 1,
|
||||
src_ty,
|
||||
dst_ty,
|
||||
},
|
||||
]
|
||||
.into_iter(),
|
||||
);
|
||||
}
|
||||
|
||||
fn translate_expected(
|
||||
&mut self,
|
||||
src_ty: TypeExpectedIndex,
|
||||
src: &Source<'_>,
|
||||
dst_ty: &InterfaceType,
|
||||
dst: &Destination,
|
||||
) {
|
||||
let src_ty = &self.module.types[src_ty];
|
||||
let dst_ty = match dst_ty {
|
||||
InterfaceType::Expected(t) => &self.module.types[*t],
|
||||
_ => panic!("expected an expected"),
|
||||
};
|
||||
|
||||
self.convert_variant(
|
||||
src,
|
||||
DiscriminantSize::Size1,
|
||||
dst,
|
||||
DiscriminantSize::Size1,
|
||||
[
|
||||
VariantCase {
|
||||
src_i: 0,
|
||||
dst_i: 0,
|
||||
src_ty: &src_ty.ok,
|
||||
dst_ty: &dst_ty.ok,
|
||||
},
|
||||
VariantCase {
|
||||
src_i: 1,
|
||||
dst_i: 1,
|
||||
src_ty: &src_ty.err,
|
||||
dst_ty: &dst_ty.err,
|
||||
},
|
||||
]
|
||||
.into_iter(),
|
||||
);
|
||||
}
|
||||
|
||||
fn convert_variant<'a>(
|
||||
&mut self,
|
||||
src: &Source<'_>,
|
||||
src_disc_size: DiscriminantSize,
|
||||
dst: &Destination,
|
||||
dst_disc_size: DiscriminantSize,
|
||||
src_cases: impl ExactSizeIterator<Item = VariantCase<'a>>,
|
||||
) {
|
||||
// The outermost block is special since it has the result type of the
|
||||
// translation here. That will depend on the `dst`.
|
||||
let outer_block_ty = match dst {
|
||||
@@ -703,7 +876,8 @@ impl Compiler<'_, '_> {
|
||||
|
||||
// After the outermost block generate a new block for each of the
|
||||
// remaining cases.
|
||||
for _ in 0..src_ty.cases.len() - 1 {
|
||||
let src_cases_len = src_cases.len();
|
||||
for _ in 0..src_cases_len - 1 {
|
||||
self.instruction(Block(BlockType::Empty));
|
||||
}
|
||||
|
||||
@@ -727,7 +901,7 @@ impl Compiler<'_, '_> {
|
||||
// Generate the `br_table` for the discriminant. Each case has an
|
||||
// offset of 1 to skip the trapping block.
|
||||
let mut targets = Vec::new();
|
||||
for i in 0..src_ty.cases.len() {
|
||||
for i in 0..src_cases_len {
|
||||
targets.push((i + 1) as u32);
|
||||
}
|
||||
self.instruction(BrTable(targets[..].into(), 0));
|
||||
@@ -740,19 +914,19 @@ impl Compiler<'_, '_> {
|
||||
// iteration order here places the first case in the innermost block
|
||||
// and the last case in the outermost block. This matches the order
|
||||
// of the jump targets in the `br_table` instruction.
|
||||
for (src_i, src_case) in src_ty.cases.iter().enumerate() {
|
||||
let dst_i = dst_ty
|
||||
.cases
|
||||
.iter()
|
||||
.position(|c| c.name == src_case.name)
|
||||
.unwrap();
|
||||
let dst_case = &dst_ty.cases[dst_i];
|
||||
let dst_i = u32::try_from(dst_i).unwrap() as i32;
|
||||
let src_cases_len = u32::try_from(src_cases_len).unwrap();
|
||||
for case in src_cases {
|
||||
let VariantCase {
|
||||
src_i,
|
||||
src_ty,
|
||||
dst_i,
|
||||
dst_ty,
|
||||
} = case;
|
||||
|
||||
// Translate the discriminant here, noting that `dst_i` may be
|
||||
// different than `src_i`.
|
||||
self.push_dst_addr(dst);
|
||||
self.instruction(I32Const(dst_i));
|
||||
self.instruction(I32Const(dst_i as i32));
|
||||
match dst {
|
||||
Destination::Stack(stack) => self.stack_set(&stack[..1], ValType::I32),
|
||||
Destination::Memory(mem) => match dst_disc_size {
|
||||
@@ -764,9 +938,9 @@ impl Compiler<'_, '_> {
|
||||
|
||||
// Translate the payload of this case using the various types from
|
||||
// the dst/src.
|
||||
let src_payload = src.payload_src(self.module, src_disc_size, &src_case.ty);
|
||||
let dst_payload = dst.payload_dst(self.module, dst_disc_size, &dst_case.ty);
|
||||
self.translate(&src_case.ty, &src_payload, &dst_case.ty, &dst_payload);
|
||||
let src_payload = src.payload_src(self.module, src_disc_size, src_ty);
|
||||
let dst_payload = dst.payload_dst(self.module, dst_disc_size, dst_ty);
|
||||
self.translate(src_ty, &src_payload, dst_ty, &dst_payload);
|
||||
|
||||
// If the results of this translation were placed on the stack then
|
||||
// the stack values may need to be padded with more zeros due to
|
||||
@@ -791,9 +965,8 @@ impl Compiler<'_, '_> {
|
||||
|
||||
// Branch to the outermost block. Note that this isn't needed for
|
||||
// the outermost case since it simply falls through.
|
||||
let src_len = src_ty.cases.len();
|
||||
if src_i != src_len - 1 {
|
||||
self.instruction(Br((src_len - src_i - 1) as u32));
|
||||
if src_i != src_cases_len - 1 {
|
||||
self.instruction(Br(src_cases_len - src_i - 1));
|
||||
}
|
||||
self.instruction(End); // end this case's block
|
||||
}
|
||||
@@ -1264,3 +1437,10 @@ impl<'a> Stack<'a> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct VariantCase<'a> {
|
||||
src_i: u32,
|
||||
src_ty: &'a InterfaceType,
|
||||
dst_i: u32,
|
||||
dst_ty: &'a InterfaceType,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user