diff --git a/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs b/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs index ba5c820ea9..5d183a82f9 100644 --- a/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs +++ b/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs @@ -16,9 +16,6 @@ use wasmparser::{Validator, WasmFeatures}; use wasmtime_environ::component::*; use wasmtime_environ::fact::Module; -// Allow inflating to 16 bits but don't go further. -const MAX_ENUM_SIZE: usize = 257; - #[derive(Arbitrary, Debug)] struct GenAdapterModule { debug: bool, @@ -56,10 +53,16 @@ enum ValType { Float64, Char, Record(Vec), + // FIXME(WebAssembly/component-model#75) are zero-sized flags allowed? + // + // ... otherwise go up to 65 flags to exercise up to 3 u32 values + Flags(UsizeInRange<1, 65>), Tuple(Vec), Variant(NonZeroLenVec), Union(NonZeroLenVec), - Enum(usize), + // at least one enum variant but no more than what's necessary to inflate to + // 16 bits to keep this reasonably sized + Enum(UsizeInRange<1, 257>), Option(Box), Expected(Box, Box), } @@ -89,6 +92,20 @@ impl fmt::Debug for NonZeroLenVec { } } +pub struct UsizeInRange(usize); + +impl<'a, const L: usize, const H: usize> Arbitrary<'a> for UsizeInRange { + fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { + Ok(UsizeInRange(u.int_in_range(L..=H)?)) + } +} + +impl fmt::Debug for UsizeInRange { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + fuzz_target!(|module: GenAdapterModule| { drop(env_logger::try_init()); @@ -228,6 +245,12 @@ fn intern(types: &mut ComponentTypesBuilder, ty: &ValType) -> InterfaceType { }; InterfaceType::Record(types.add_record_type(ty)) } + ValType::Flags(size) => { + let ty = TypeFlags { + names: (0..size.0).map(|i| format!("f{i}")).collect(), + }; + InterfaceType::Flags(types.add_flags_type(ty)) + } ValType::Tuple(tys) => { let ty = TypeTuple { types: tys.iter().map(|ty| intern(types, ty)).collect(), @@ -254,10 +277,8 @@ fn intern(types: &mut ComponentTypesBuilder, ty: &ValType) -> InterfaceType { InterfaceType::Union(types.add_union_type(ty)) } ValType::Enum(size) => { - let size = size % MAX_ENUM_SIZE; - let size = if size == 0 { 1 } else { size }; let ty = TypeEnum { - names: (0..size).map(|i| format!("c{i}")).collect(), + names: (0..size.0).map(|i| format!("c{i}")).collect(), }; InterfaceType::Enum(types.add_enum_type(ty)) } diff --git a/crates/environ/src/component/types.rs b/crates/environ/src/component/types.rs index b1b553782e..07b0ab9d7e 100644 --- a/crates/environ/src/component/types.rs +++ b/crates/environ/src/component/types.rs @@ -660,7 +660,7 @@ impl ComponentTypesBuilder { let flags = TypeFlags { names: flags.iter().map(|s| s.to_string()).collect(), }; - intern(&mut self.flags, &mut self.component_types.flags, flags) + self.add_flags_type(flags) } fn enum_type(&mut self, variants: &[&str]) -> TypeEnumIndex { @@ -699,6 +699,11 @@ impl ComponentTypesBuilder { intern(&mut self.records, &mut self.component_types.records, ty) } + /// Interns a new flags type within this type information. + pub fn add_flags_type(&mut self, ty: TypeFlags) -> TypeFlagsIndex { + intern(&mut self.flags, &mut self.component_types.flags, ty) + } + /// Interns a new tuple type within this type information. pub fn add_tuple_type(&mut self, ty: TypeTuple) -> TypeTupleIndex { intern(&mut self.tuples, &mut self.component_types.tuples, ty) diff --git a/crates/environ/src/fact/trampoline.rs b/crates/environ/src/fact/trampoline.rs index d4b2f8ed32..b50000208a 100644 --- a/crates/environ/src/fact/trampoline.rs +++ b/crates/environ/src/fact/trampoline.rs @@ -16,9 +16,9 @@ //! can be somewhat arbitrary, an intentional decision. use crate::component::{ - InterfaceType, TypeEnumIndex, TypeExpectedIndex, TypeInterfaceIndex, TypeRecordIndex, - TypeTupleIndex, TypeUnionIndex, TypeVariantIndex, FLAG_MAY_ENTER, FLAG_MAY_LEAVE, - MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, + InterfaceType, TypeEnumIndex, TypeExpectedIndex, TypeFlagsIndex, TypeInterfaceIndex, + TypeRecordIndex, TypeTupleIndex, TypeUnionIndex, TypeVariantIndex, FLAG_MAY_ENTER, + FLAG_MAY_LEAVE, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, }; use crate::fact::core_types::CoreTypes; use crate::fact::signature::{align_to, Signature}; @@ -29,7 +29,7 @@ use std::collections::HashMap; use std::mem; use std::ops::Range; use wasm_encoder::{BlockType, Encode, Instruction, Instruction::*, MemArg, ValType}; -use wasmtime_component_util::DiscriminantSize; +use wasmtime_component_util::{DiscriminantSize, FlagsSize}; struct Compiler<'a, 'b> { /// The module that the adapter will eventually be inserted into. @@ -349,6 +349,7 @@ impl Compiler<'_, '_> { InterfaceType::Float64 => self.translate_f64(src, dst_ty, dst), InterfaceType::Char => self.translate_char(src, dst_ty, dst), InterfaceType::Record(t) => self.translate_record(*t, src, dst_ty, dst), + InterfaceType::Flags(f) => self.translate_flags(*f, src, dst_ty, dst), InterfaceType::Tuple(t) => self.translate_tuple(*t, src, dst_ty, dst), InterfaceType::Variant(v) => self.translate_variant(*v, src, dst_ty, dst), InterfaceType::Union(u) => self.translate_union(*u, src, dst_ty, dst), @@ -399,15 +400,25 @@ impl Compiler<'_, '_> { fn translate_u8(&mut self, src: &Source<'_>, dst_ty: &InterfaceType, dst: &Destination) { // TODO: subtyping assert!(matches!(dst_ty, InterfaceType::U8)); + self.convert_u8_mask(src, dst, 0xff); + } + + fn convert_u8_mask(&mut self, src: &Source<'_>, dst: &Destination<'_>, mask: u8) { self.push_dst_addr(dst); + let mut needs_mask = true; match src { - Source::Memory(mem) => self.i32_load8u(mem), + Source::Memory(mem) => { + self.i32_load8u(mem); + needs_mask = mask != 0xff; + } Source::Stack(stack) => { self.stack_get(stack, ValType::I32); - self.instruction(I32Const(0xff)); - self.instruction(I32And); } } + if needs_mask { + self.instruction(I32Const(i32::from(mask))); + self.instruction(I32And); + } match dst { Destination::Memory(mem) => self.i32_store8(mem), Destination::Stack(stack) => self.stack_set(stack, ValType::I32), @@ -434,15 +445,25 @@ impl Compiler<'_, '_> { fn translate_u16(&mut self, src: &Source<'_>, dst_ty: &InterfaceType, dst: &Destination) { // TODO: subtyping assert!(matches!(dst_ty, InterfaceType::U16)); + self.convert_u16_mask(src, dst, 0xffff); + } + + fn convert_u16_mask(&mut self, src: &Source<'_>, dst: &Destination<'_>, mask: u16) { self.push_dst_addr(dst); + let mut needs_mask = true; match src { - Source::Memory(mem) => self.i32_load16u(mem), + Source::Memory(mem) => { + self.i32_load16u(mem); + needs_mask = mask != 0xffff; + } Source::Stack(stack) => { self.stack_get(stack, ValType::I32); - self.instruction(I32Const(0xffff)); - self.instruction(I32And); } } + if needs_mask { + self.instruction(I32Const(i32::from(mask))); + self.instruction(I32And); + } match dst { Destination::Memory(mem) => self.i32_store16(mem), Destination::Stack(stack) => self.stack_set(stack, ValType::I32), @@ -469,11 +490,19 @@ impl Compiler<'_, '_> { fn translate_u32(&mut self, src: &Source<'_>, dst_ty: &InterfaceType, dst: &Destination) { // TODO: subtyping assert!(matches!(dst_ty, InterfaceType::U32)); + self.convert_u32_mask(src, dst, 0xffffffff) + } + + fn convert_u32_mask(&mut self, src: &Source<'_>, dst: &Destination<'_>, mask: u32) { self.push_dst_addr(dst); match src { - Source::Memory(mem) => self.i32_load(mem), + Source::Memory(mem) => self.i32_load16u(mem), Source::Stack(stack) => self.stack_get(stack, ValType::I32), } + if mask != 0xffffffff { + self.instruction(I32Const(mask as i32)); + self.instruction(I32And); + } match dst { Destination::Memory(mem) => self.i32_store(mem), Destination::Stack(stack) => self.stack_set(stack, ValType::I32), @@ -648,6 +677,52 @@ impl Compiler<'_, '_> { } } + fn translate_flags( + &mut self, + src_ty: TypeFlagsIndex, + src: &Source<'_>, + dst_ty: &InterfaceType, + dst: &Destination, + ) { + let src_ty = &self.module.types[src_ty]; + let dst_ty = match dst_ty { + InterfaceType::Flags(r) => &self.module.types[*r], + _ => panic!("expected a record"), + }; + + // TODO: subtyping + // + // Notably this implementation does not support reordering flags from + // the source to the destination nor having more flags in the + // destination. Currently this is a copy from source to destination + // in-bulk. Otherwise reordering indices would have to have some sort of + // fancy bit twiddling tricks or something like that. + assert_eq!(src_ty.names, dst_ty.names); + let cnt = src_ty.names.len(); + match FlagsSize::from_count(cnt) { + FlagsSize::Size1 => { + let mask = if cnt == 8 { 0xff } else { (1 << cnt) - 1 }; + self.convert_u8_mask(src, dst, mask); + } + FlagsSize::Size2 => { + let mask = if cnt == 16 { 0xffff } else { (1 << cnt) - 1 }; + self.convert_u16_mask(src, dst, mask); + } + FlagsSize::Size4Plus(n) => { + let srcs = src.record_field_srcs(self.module, (0..n).map(|_| InterfaceType::U32)); + let dsts = dst.record_field_dsts(self.module, (0..n).map(|_| InterfaceType::U32)); + for (i, (src, dst)) in srcs.zip(dsts).enumerate() { + let mask = if i == n - 1 && (cnt % 32 != 0) { + (1 << (cnt % 32)) - 1 + } else { + 0xffffffff + }; + self.convert_u32_mask(&src, &dst, mask); + } + } + } + } + fn translate_tuple( &mut self, src_ty: TypeTupleIndex, diff --git a/tests/misc_testsuite/component-model/fused.wast b/tests/misc_testsuite/component-model/fused.wast index 5c19f684de..fbab3d704e 100644 --- a/tests/misc_testsuite/component-model/fused.wast +++ b/tests/misc_testsuite/component-model/fused.wast @@ -1225,3 +1225,163 @@ (instance $c2 (instantiate $c2 (with "" (instance $c1)))) ) "unreachable") + +;; test that flags get their upper bits all masked off +(component + (type $f1 (flags "f1")) + (type $f8 (flags "f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8")) + (type $f9 (flags "f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8" "f9")) + (type $f16 (flags + "f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8" + "g1" "g2" "g3" "g4" "g5" "g6" "g7" "g8" + )) + (type $f17 (flags + "f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8" + "g1" "g2" "g3" "g4" "g5" "g6" "g7" "g8" + "g9" + )) + (type $f32 (flags + "f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8" + "g1" "g2" "g3" "g4" "g5" "g6" "g7" "g8" + "h1" "h2" "h3" "h4" "h5" "h6" "h7" "h8" + "i1" "i2" "i3" "i4" "i5" "i6" "i7" "i8" + )) + (type $f33 (flags + "f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8" + "g1" "g2" "g3" "g4" "g5" "g6" "g7" "g8" + "h1" "h2" "h3" "h4" "h5" "h6" "h7" "h8" + "i1" "i2" "i3" "i4" "i5" "i6" "i7" "i8" + "i9" + )) + (type $f64 (flags + "f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8" + "g1" "g2" "g3" "g4" "g5" "g6" "g7" "g8" + "h1" "h2" "h3" "h4" "h5" "h6" "h7" "h8" + "i1" "i2" "i3" "i4" "i5" "i6" "i7" "i8" + "j1" "j2" "j3" "j4" "j5" "j6" "j7" "j8" + "k1" "k2" "k3" "k4" "k5" "k6" "k7" "k8" + "l1" "l2" "l3" "l4" "l5" "l6" "l7" "l8" + "m1" "m2" "m3" "m4" "m5" "m6" "m7" "m8" + )) + (type $f65 (flags + "f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8" + "g1" "g2" "g3" "g4" "g5" "g6" "g7" "g8" + "h1" "h2" "h3" "h4" "h5" "h6" "h7" "h8" + "i1" "i2" "i3" "i4" "i5" "i6" "i7" "i8" + "j1" "j2" "j3" "j4" "j5" "j6" "j7" "j8" + "k1" "k2" "k3" "k4" "k5" "k6" "k7" "k8" + "l1" "l2" "l3" "l4" "l5" "l6" "l7" "l8" + "m1" "m2" "m3" "m4" "m5" "m6" "m7" "m8" + "m9" + )) + + (component $c1 + (core module $m + (func (export "f1") (param i32) + (if (i32.ne (local.get 0) (i32.const 0x1)) (unreachable)) + ) + (func (export "f8") (param i32) + (if (i32.ne (local.get 0) (i32.const 0x11)) (unreachable)) + ) + (func (export "f9") (param i32) + (if (i32.ne (local.get 0) (i32.const 0x111)) (unreachable)) + ) + (func (export "f16") (param i32) + (if (i32.ne (local.get 0) (i32.const 0x1111)) (unreachable)) + ) + (func (export "f17") (param i32) + (if (i32.ne (local.get 0) (i32.const 0x11111)) (unreachable)) + ) + (func (export "f32") (param i32) + (if (i32.ne (local.get 0) (i32.const 0x11111111)) (unreachable)) + ) + (func (export "f33") (param i32 i32) + (if (i32.ne (local.get 0) (i32.const 0x11111111)) (unreachable)) + (if (i32.ne (local.get 1) (i32.const 0x1)) (unreachable)) + ) + (func (export "f64") (param i32 i32) + (if (i32.ne (local.get 0) (i32.const 0x11111111)) (unreachable)) + (if (i32.ne (local.get 1) (i32.const 0x11111111)) (unreachable)) + ) + (func (export "f65") (param i32 i32 i32) + (if (i32.ne (local.get 0) (i32.const 0x11111111)) (unreachable)) + (if (i32.ne (local.get 1) (i32.const 0x11111111)) (unreachable)) + (if (i32.ne (local.get 2) (i32.const 0x1)) (unreachable)) + ) + ) + (core instance $m (instantiate $m)) + (func (export "f1") (param $f1) (canon lift (core func $m "f1"))) + (func (export "f8") (param $f8) (canon lift (core func $m "f8"))) + (func (export "f9") (param $f9) (canon lift (core func $m "f9"))) + (func (export "f16") (param $f16) (canon lift (core func $m "f16"))) + (func (export "f17") (param $f17) (canon lift (core func $m "f17"))) + (func (export "f32") (param $f32) (canon lift (core func $m "f32"))) + (func (export "f33") (param $f33) (canon lift (core func $m "f33"))) + (func (export "f64") (param $f64) (canon lift (core func $m "f64"))) + (func (export "f65") (param $f65) (canon lift (core func $m "f65"))) + ) + (instance $c1 (instantiate $c1)) + + (component $c2 + (import "" (instance $i + (export "f1" (func (param $f1))) + (export "f8" (func (param $f8))) + (export "f9" (func (param $f9))) + (export "f16" (func (param $f16))) + (export "f17" (func (param $f17))) + (export "f32" (func (param $f32))) + (export "f33" (func (param $f33))) + (export "f64" (func (param $f64))) + (export "f65" (func (param $f65))) + )) + (core func $f1 (canon lower (func $i "f1"))) + (core func $f8 (canon lower (func $i "f8"))) + (core func $f9 (canon lower (func $i "f9"))) + (core func $f16 (canon lower (func $i "f16"))) + (core func $f17 (canon lower (func $i "f17"))) + (core func $f32 (canon lower (func $i "f32"))) + (core func $f33 (canon lower (func $i "f33"))) + (core func $f64 (canon lower (func $i "f64"))) + (core func $f65 (canon lower (func $i "f65"))) + + (core module $m + (import "" "f1" (func $f1 (param i32))) + (import "" "f8" (func $f8 (param i32))) + (import "" "f9" (func $f9 (param i32))) + (import "" "f16" (func $f16 (param i32))) + (import "" "f17" (func $f17 (param i32))) + (import "" "f32" (func $f32 (param i32))) + (import "" "f33" (func $f33 (param i32 i32))) + (import "" "f64" (func $f64 (param i32 i32))) + (import "" "f65" (func $f65 (param i32 i32 i32))) + + (func $start + (call $f1 (i32.const 0xffffff01)) + (call $f8 (i32.const 0xffffff11)) + (call $f9 (i32.const 0xffffff11)) + (call $f16 (i32.const 0xffff1111)) + (call $f17 (i32.const 0xffff1111)) + (call $f32 (i32.const 0x11111111)) + (call $f33 (i32.const 0x11111111) (i32.const 0xffffffff)) + (call $f64 (i32.const 0x11111111) (i32.const 0x11111111)) + (call $f65 (i32.const 0x11111111) (i32.const 0x11111111) (i32.const 0xffffffff)) + ) + + (start $start) + ) + (core instance $m (instantiate $m + (with "" (instance + (export "f1" (func $f1)) + (export "f8" (func $f8)) + (export "f9" (func $f9)) + (export "f16" (func $f16)) + (export "f17" (func $f17)) + (export "f32" (func $f32)) + (export "f33" (func $f33)) + (export "f64" (func $f64)) + (export "f65" (func $f65)) + )) + )) + ) + (instance (instantiate $c2 (with "" (instance $c1)))) +)