Implement flags in fused adapters (#4549)

This implements the `flags` type for fused adapters and converting
between modules. The main logic here is handling the variable size of
flags in addition to the masking which happens to ignore unrelated bits
when the values pass through the canonical ABI.
This commit is contained in:
Alex Crichton
2022-07-28 14:56:32 -05:00
committed by GitHub
parent fb7d51033c
commit 32979b2714
4 changed files with 280 additions and 19 deletions

View File

@@ -16,9 +16,6 @@ use wasmparser::{Validator, WasmFeatures};
use wasmtime_environ::component::*;
use wasmtime_environ::fact::Module;
// Allow inflating to 16 bits but don't go further.
const MAX_ENUM_SIZE: usize = 257;
#[derive(Arbitrary, Debug)]
struct GenAdapterModule {
debug: bool,
@@ -56,10 +53,16 @@ enum ValType {
Float64,
Char,
Record(Vec<ValType>),
// FIXME(WebAssembly/component-model#75) are zero-sized flags allowed?
//
// ... otherwise go up to 65 flags to exercise up to 3 u32 values
Flags(UsizeInRange<1, 65>),
Tuple(Vec<ValType>),
Variant(NonZeroLenVec<ValType>),
Union(NonZeroLenVec<ValType>),
Enum(usize),
// at least one enum variant but no more than what's necessary to inflate to
// 16 bits to keep this reasonably sized
Enum(UsizeInRange<1, 257>),
Option(Box<ValType>),
Expected(Box<ValType>, Box<ValType>),
}
@@ -89,6 +92,20 @@ impl<T: fmt::Debug> fmt::Debug for NonZeroLenVec<T> {
}
}
pub struct UsizeInRange<const L: usize, const H: usize>(usize);
impl<'a, const L: usize, const H: usize> Arbitrary<'a> for UsizeInRange<L, H> {
fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
Ok(UsizeInRange(u.int_in_range(L..=H)?))
}
}
impl<const L: usize, const H: usize> fmt::Debug for UsizeInRange<L, H> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
fuzz_target!(|module: GenAdapterModule| {
drop(env_logger::try_init());
@@ -228,6 +245,12 @@ fn intern(types: &mut ComponentTypesBuilder, ty: &ValType) -> InterfaceType {
};
InterfaceType::Record(types.add_record_type(ty))
}
ValType::Flags(size) => {
let ty = TypeFlags {
names: (0..size.0).map(|i| format!("f{i}")).collect(),
};
InterfaceType::Flags(types.add_flags_type(ty))
}
ValType::Tuple(tys) => {
let ty = TypeTuple {
types: tys.iter().map(|ty| intern(types, ty)).collect(),
@@ -254,10 +277,8 @@ fn intern(types: &mut ComponentTypesBuilder, ty: &ValType) -> InterfaceType {
InterfaceType::Union(types.add_union_type(ty))
}
ValType::Enum(size) => {
let size = size % MAX_ENUM_SIZE;
let size = if size == 0 { 1 } else { size };
let ty = TypeEnum {
names: (0..size).map(|i| format!("c{i}")).collect(),
names: (0..size.0).map(|i| format!("c{i}")).collect(),
};
InterfaceType::Enum(types.add_enum_type(ty))
}

View File

@@ -660,7 +660,7 @@ impl ComponentTypesBuilder {
let flags = TypeFlags {
names: flags.iter().map(|s| s.to_string()).collect(),
};
intern(&mut self.flags, &mut self.component_types.flags, flags)
self.add_flags_type(flags)
}
fn enum_type(&mut self, variants: &[&str]) -> TypeEnumIndex {
@@ -699,6 +699,11 @@ impl ComponentTypesBuilder {
intern(&mut self.records, &mut self.component_types.records, ty)
}
/// Interns a new flags type within this type information.
pub fn add_flags_type(&mut self, ty: TypeFlags) -> TypeFlagsIndex {
intern(&mut self.flags, &mut self.component_types.flags, ty)
}
/// Interns a new tuple type within this type information.
pub fn add_tuple_type(&mut self, ty: TypeTuple) -> TypeTupleIndex {
intern(&mut self.tuples, &mut self.component_types.tuples, ty)

View File

@@ -16,9 +16,9 @@
//! can be somewhat arbitrary, an intentional decision.
use crate::component::{
InterfaceType, TypeEnumIndex, TypeExpectedIndex, TypeInterfaceIndex, TypeRecordIndex,
TypeTupleIndex, TypeUnionIndex, TypeVariantIndex, FLAG_MAY_ENTER, FLAG_MAY_LEAVE,
MAX_FLAT_PARAMS, MAX_FLAT_RESULTS,
InterfaceType, TypeEnumIndex, TypeExpectedIndex, TypeFlagsIndex, TypeInterfaceIndex,
TypeRecordIndex, TypeTupleIndex, TypeUnionIndex, TypeVariantIndex, FLAG_MAY_ENTER,
FLAG_MAY_LEAVE, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS,
};
use crate::fact::core_types::CoreTypes;
use crate::fact::signature::{align_to, Signature};
@@ -29,7 +29,7 @@ use std::collections::HashMap;
use std::mem;
use std::ops::Range;
use wasm_encoder::{BlockType, Encode, Instruction, Instruction::*, MemArg, ValType};
use wasmtime_component_util::DiscriminantSize;
use wasmtime_component_util::{DiscriminantSize, FlagsSize};
struct Compiler<'a, 'b> {
/// The module that the adapter will eventually be inserted into.
@@ -349,6 +349,7 @@ impl Compiler<'_, '_> {
InterfaceType::Float64 => self.translate_f64(src, dst_ty, dst),
InterfaceType::Char => self.translate_char(src, dst_ty, dst),
InterfaceType::Record(t) => self.translate_record(*t, src, dst_ty, dst),
InterfaceType::Flags(f) => self.translate_flags(*f, src, dst_ty, dst),
InterfaceType::Tuple(t) => self.translate_tuple(*t, src, dst_ty, dst),
InterfaceType::Variant(v) => self.translate_variant(*v, src, dst_ty, dst),
InterfaceType::Union(u) => self.translate_union(*u, src, dst_ty, dst),
@@ -399,15 +400,25 @@ impl Compiler<'_, '_> {
fn translate_u8(&mut self, src: &Source<'_>, dst_ty: &InterfaceType, dst: &Destination) {
// TODO: subtyping
assert!(matches!(dst_ty, InterfaceType::U8));
self.convert_u8_mask(src, dst, 0xff);
}
fn convert_u8_mask(&mut self, src: &Source<'_>, dst: &Destination<'_>, mask: u8) {
self.push_dst_addr(dst);
let mut needs_mask = true;
match src {
Source::Memory(mem) => self.i32_load8u(mem),
Source::Memory(mem) => {
self.i32_load8u(mem);
needs_mask = mask != 0xff;
}
Source::Stack(stack) => {
self.stack_get(stack, ValType::I32);
self.instruction(I32Const(0xff));
self.instruction(I32And);
}
}
if needs_mask {
self.instruction(I32Const(i32::from(mask)));
self.instruction(I32And);
}
match dst {
Destination::Memory(mem) => self.i32_store8(mem),
Destination::Stack(stack) => self.stack_set(stack, ValType::I32),
@@ -434,15 +445,25 @@ impl Compiler<'_, '_> {
fn translate_u16(&mut self, src: &Source<'_>, dst_ty: &InterfaceType, dst: &Destination) {
// TODO: subtyping
assert!(matches!(dst_ty, InterfaceType::U16));
self.convert_u16_mask(src, dst, 0xffff);
}
fn convert_u16_mask(&mut self, src: &Source<'_>, dst: &Destination<'_>, mask: u16) {
self.push_dst_addr(dst);
let mut needs_mask = true;
match src {
Source::Memory(mem) => self.i32_load16u(mem),
Source::Memory(mem) => {
self.i32_load16u(mem);
needs_mask = mask != 0xffff;
}
Source::Stack(stack) => {
self.stack_get(stack, ValType::I32);
self.instruction(I32Const(0xffff));
self.instruction(I32And);
}
}
if needs_mask {
self.instruction(I32Const(i32::from(mask)));
self.instruction(I32And);
}
match dst {
Destination::Memory(mem) => self.i32_store16(mem),
Destination::Stack(stack) => self.stack_set(stack, ValType::I32),
@@ -469,11 +490,19 @@ impl Compiler<'_, '_> {
fn translate_u32(&mut self, src: &Source<'_>, dst_ty: &InterfaceType, dst: &Destination) {
// TODO: subtyping
assert!(matches!(dst_ty, InterfaceType::U32));
self.convert_u32_mask(src, dst, 0xffffffff)
}
fn convert_u32_mask(&mut self, src: &Source<'_>, dst: &Destination<'_>, mask: u32) {
self.push_dst_addr(dst);
match src {
Source::Memory(mem) => self.i32_load(mem),
Source::Memory(mem) => self.i32_load16u(mem),
Source::Stack(stack) => self.stack_get(stack, ValType::I32),
}
if mask != 0xffffffff {
self.instruction(I32Const(mask as i32));
self.instruction(I32And);
}
match dst {
Destination::Memory(mem) => self.i32_store(mem),
Destination::Stack(stack) => self.stack_set(stack, ValType::I32),
@@ -648,6 +677,52 @@ impl Compiler<'_, '_> {
}
}
fn translate_flags(
&mut self,
src_ty: TypeFlagsIndex,
src: &Source<'_>,
dst_ty: &InterfaceType,
dst: &Destination,
) {
let src_ty = &self.module.types[src_ty];
let dst_ty = match dst_ty {
InterfaceType::Flags(r) => &self.module.types[*r],
_ => panic!("expected a record"),
};
// TODO: subtyping
//
// Notably this implementation does not support reordering flags from
// the source to the destination nor having more flags in the
// destination. Currently this is a copy from source to destination
// in-bulk. Otherwise reordering indices would have to have some sort of
// fancy bit twiddling tricks or something like that.
assert_eq!(src_ty.names, dst_ty.names);
let cnt = src_ty.names.len();
match FlagsSize::from_count(cnt) {
FlagsSize::Size1 => {
let mask = if cnt == 8 { 0xff } else { (1 << cnt) - 1 };
self.convert_u8_mask(src, dst, mask);
}
FlagsSize::Size2 => {
let mask = if cnt == 16 { 0xffff } else { (1 << cnt) - 1 };
self.convert_u16_mask(src, dst, mask);
}
FlagsSize::Size4Plus(n) => {
let srcs = src.record_field_srcs(self.module, (0..n).map(|_| InterfaceType::U32));
let dsts = dst.record_field_dsts(self.module, (0..n).map(|_| InterfaceType::U32));
for (i, (src, dst)) in srcs.zip(dsts).enumerate() {
let mask = if i == n - 1 && (cnt % 32 != 0) {
(1 << (cnt % 32)) - 1
} else {
0xffffffff
};
self.convert_u32_mask(&src, &dst, mask);
}
}
}
}
fn translate_tuple(
&mut self,
src_ty: TypeTupleIndex,

View File

@@ -1225,3 +1225,163 @@
(instance $c2 (instantiate $c2 (with "" (instance $c1))))
)
"unreachable")
;; test that flags get their upper bits all masked off
(component
(type $f1 (flags "f1"))
(type $f8 (flags "f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8"))
(type $f9 (flags "f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8" "f9"))
(type $f16 (flags
"f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8"
"g1" "g2" "g3" "g4" "g5" "g6" "g7" "g8"
))
(type $f17 (flags
"f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8"
"g1" "g2" "g3" "g4" "g5" "g6" "g7" "g8"
"g9"
))
(type $f32 (flags
"f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8"
"g1" "g2" "g3" "g4" "g5" "g6" "g7" "g8"
"h1" "h2" "h3" "h4" "h5" "h6" "h7" "h8"
"i1" "i2" "i3" "i4" "i5" "i6" "i7" "i8"
))
(type $f33 (flags
"f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8"
"g1" "g2" "g3" "g4" "g5" "g6" "g7" "g8"
"h1" "h2" "h3" "h4" "h5" "h6" "h7" "h8"
"i1" "i2" "i3" "i4" "i5" "i6" "i7" "i8"
"i9"
))
(type $f64 (flags
"f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8"
"g1" "g2" "g3" "g4" "g5" "g6" "g7" "g8"
"h1" "h2" "h3" "h4" "h5" "h6" "h7" "h8"
"i1" "i2" "i3" "i4" "i5" "i6" "i7" "i8"
"j1" "j2" "j3" "j4" "j5" "j6" "j7" "j8"
"k1" "k2" "k3" "k4" "k5" "k6" "k7" "k8"
"l1" "l2" "l3" "l4" "l5" "l6" "l7" "l8"
"m1" "m2" "m3" "m4" "m5" "m6" "m7" "m8"
))
(type $f65 (flags
"f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8"
"g1" "g2" "g3" "g4" "g5" "g6" "g7" "g8"
"h1" "h2" "h3" "h4" "h5" "h6" "h7" "h8"
"i1" "i2" "i3" "i4" "i5" "i6" "i7" "i8"
"j1" "j2" "j3" "j4" "j5" "j6" "j7" "j8"
"k1" "k2" "k3" "k4" "k5" "k6" "k7" "k8"
"l1" "l2" "l3" "l4" "l5" "l6" "l7" "l8"
"m1" "m2" "m3" "m4" "m5" "m6" "m7" "m8"
"m9"
))
(component $c1
(core module $m
(func (export "f1") (param i32)
(if (i32.ne (local.get 0) (i32.const 0x1)) (unreachable))
)
(func (export "f8") (param i32)
(if (i32.ne (local.get 0) (i32.const 0x11)) (unreachable))
)
(func (export "f9") (param i32)
(if (i32.ne (local.get 0) (i32.const 0x111)) (unreachable))
)
(func (export "f16") (param i32)
(if (i32.ne (local.get 0) (i32.const 0x1111)) (unreachable))
)
(func (export "f17") (param i32)
(if (i32.ne (local.get 0) (i32.const 0x11111)) (unreachable))
)
(func (export "f32") (param i32)
(if (i32.ne (local.get 0) (i32.const 0x11111111)) (unreachable))
)
(func (export "f33") (param i32 i32)
(if (i32.ne (local.get 0) (i32.const 0x11111111)) (unreachable))
(if (i32.ne (local.get 1) (i32.const 0x1)) (unreachable))
)
(func (export "f64") (param i32 i32)
(if (i32.ne (local.get 0) (i32.const 0x11111111)) (unreachable))
(if (i32.ne (local.get 1) (i32.const 0x11111111)) (unreachable))
)
(func (export "f65") (param i32 i32 i32)
(if (i32.ne (local.get 0) (i32.const 0x11111111)) (unreachable))
(if (i32.ne (local.get 1) (i32.const 0x11111111)) (unreachable))
(if (i32.ne (local.get 2) (i32.const 0x1)) (unreachable))
)
)
(core instance $m (instantiate $m))
(func (export "f1") (param $f1) (canon lift (core func $m "f1")))
(func (export "f8") (param $f8) (canon lift (core func $m "f8")))
(func (export "f9") (param $f9) (canon lift (core func $m "f9")))
(func (export "f16") (param $f16) (canon lift (core func $m "f16")))
(func (export "f17") (param $f17) (canon lift (core func $m "f17")))
(func (export "f32") (param $f32) (canon lift (core func $m "f32")))
(func (export "f33") (param $f33) (canon lift (core func $m "f33")))
(func (export "f64") (param $f64) (canon lift (core func $m "f64")))
(func (export "f65") (param $f65) (canon lift (core func $m "f65")))
)
(instance $c1 (instantiate $c1))
(component $c2
(import "" (instance $i
(export "f1" (func (param $f1)))
(export "f8" (func (param $f8)))
(export "f9" (func (param $f9)))
(export "f16" (func (param $f16)))
(export "f17" (func (param $f17)))
(export "f32" (func (param $f32)))
(export "f33" (func (param $f33)))
(export "f64" (func (param $f64)))
(export "f65" (func (param $f65)))
))
(core func $f1 (canon lower (func $i "f1")))
(core func $f8 (canon lower (func $i "f8")))
(core func $f9 (canon lower (func $i "f9")))
(core func $f16 (canon lower (func $i "f16")))
(core func $f17 (canon lower (func $i "f17")))
(core func $f32 (canon lower (func $i "f32")))
(core func $f33 (canon lower (func $i "f33")))
(core func $f64 (canon lower (func $i "f64")))
(core func $f65 (canon lower (func $i "f65")))
(core module $m
(import "" "f1" (func $f1 (param i32)))
(import "" "f8" (func $f8 (param i32)))
(import "" "f9" (func $f9 (param i32)))
(import "" "f16" (func $f16 (param i32)))
(import "" "f17" (func $f17 (param i32)))
(import "" "f32" (func $f32 (param i32)))
(import "" "f33" (func $f33 (param i32 i32)))
(import "" "f64" (func $f64 (param i32 i32)))
(import "" "f65" (func $f65 (param i32 i32 i32)))
(func $start
(call $f1 (i32.const 0xffffff01))
(call $f8 (i32.const 0xffffff11))
(call $f9 (i32.const 0xffffff11))
(call $f16 (i32.const 0xffff1111))
(call $f17 (i32.const 0xffff1111))
(call $f32 (i32.const 0x11111111))
(call $f33 (i32.const 0x11111111) (i32.const 0xffffffff))
(call $f64 (i32.const 0x11111111) (i32.const 0x11111111))
(call $f65 (i32.const 0x11111111) (i32.const 0x11111111) (i32.const 0xffffffff))
)
(start $start)
)
(core instance $m (instantiate $m
(with "" (instance
(export "f1" (func $f1))
(export "f8" (func $f8))
(export "f9" (func $f9))
(export "f16" (func $f16))
(export "f17" (func $f17))
(export "f32" (func $f32))
(export "f33" (func $f33))
(export "f64" (func $f64))
(export "f65" (func $f65))
))
))
)
(instance (instantiate $c2 (with "" (instance $c1))))
)