components: Fix support for 0-sized flags (#4560)
This commit goes through and updates support in the various argument passing routines to support 0-sized flags. A bit of a degenerate case but clarified in WebAssembly/component-model#76 as intentional.
This commit is contained in:
@@ -352,6 +352,7 @@ fn expand_record_for_component_type(
|
|||||||
#[repr(C)]
|
#[repr(C)]
|
||||||
pub struct #lower <#lower_generic_params> {
|
pub struct #lower <#lower_generic_params> {
|
||||||
#lower_field_declarations
|
#lower_field_declarations
|
||||||
|
_align: [wasmtime::ValRaw; 0],
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl #impl_generics wasmtime::component::ComponentType for #name #ty_generics #where_clause {
|
unsafe impl #impl_generics wasmtime::component::ComponentType for #name #ty_generics #where_clause {
|
||||||
@@ -965,6 +966,10 @@ fn expand_flags(flags: &Flags) -> Result<TokenStream> {
|
|||||||
let count = flags.flags.len();
|
let count = flags.flags.len();
|
||||||
|
|
||||||
match size {
|
match size {
|
||||||
|
FlagsSize::Size0 => {
|
||||||
|
ty = quote!(());
|
||||||
|
eq = quote!(true);
|
||||||
|
}
|
||||||
FlagsSize::Size1 => {
|
FlagsSize::Size1 => {
|
||||||
ty = quote!(u8);
|
ty = quote!(u8);
|
||||||
|
|
||||||
@@ -1021,6 +1026,17 @@ fn expand_flags(flags: &Flags) -> Result<TokenStream> {
|
|||||||
let mut not;
|
let mut not;
|
||||||
|
|
||||||
match size {
|
match size {
|
||||||
|
FlagsSize::Size0 => {
|
||||||
|
count = 0;
|
||||||
|
as_array = quote!([]);
|
||||||
|
bitor = quote!(Self {});
|
||||||
|
bitor_assign = quote!();
|
||||||
|
bitand = quote!(Self {});
|
||||||
|
bitand_assign = quote!();
|
||||||
|
bitxor = quote!(Self {});
|
||||||
|
bitxor_assign = quote!();
|
||||||
|
not = quote!(Self {});
|
||||||
|
}
|
||||||
FlagsSize::Size1 | FlagsSize::Size2 => {
|
FlagsSize::Size1 | FlagsSize::Size2 => {
|
||||||
count = 1;
|
count = 1;
|
||||||
as_array = quote!([self.__inner0 as u32]);
|
as_array = quote!([self.__inner0 as u32]);
|
||||||
@@ -1085,6 +1101,7 @@ fn expand_flags(flags: &Flags) -> Result<TokenStream> {
|
|||||||
component_names.extend(quote!(#component_name,));
|
component_names.extend(quote!(#component_name,));
|
||||||
|
|
||||||
let fields = match size {
|
let fields = match size {
|
||||||
|
FlagsSize::Size0 => quote!(),
|
||||||
FlagsSize::Size1 => {
|
FlagsSize::Size1 => {
|
||||||
let init = 1_u8 << index;
|
let init = 1_u8 << index;
|
||||||
quote!(__inner0: #init)
|
quote!(__inner0: #init)
|
||||||
|
|||||||
@@ -48,6 +48,8 @@ impl From<DiscriminantSize> for usize {
|
|||||||
|
|
||||||
/// Represents the number of bytes required to store a flags value in the component model
|
/// Represents the number of bytes required to store a flags value in the component model
|
||||||
pub enum FlagsSize {
|
pub enum FlagsSize {
|
||||||
|
/// There are no flags
|
||||||
|
Size0,
|
||||||
/// Flags can fit in a u8
|
/// Flags can fit in a u8
|
||||||
Size1,
|
Size1,
|
||||||
/// Flags can fit in a u16
|
/// Flags can fit in a u16
|
||||||
@@ -59,7 +61,9 @@ pub enum FlagsSize {
|
|||||||
impl FlagsSize {
|
impl FlagsSize {
|
||||||
/// Calculate the size needed to represent a value with the specified number of flags.
|
/// Calculate the size needed to represent a value with the specified number of flags.
|
||||||
pub fn from_count(count: usize) -> FlagsSize {
|
pub fn from_count(count: usize) -> FlagsSize {
|
||||||
if count <= 8 {
|
if count == 0 {
|
||||||
|
FlagsSize::Size0
|
||||||
|
} else if count <= 8 {
|
||||||
FlagsSize::Size1
|
FlagsSize::Size1
|
||||||
} else if count <= 16 {
|
} else if count <= 16 {
|
||||||
FlagsSize::Size2
|
FlagsSize::Size2
|
||||||
|
|||||||
@@ -53,10 +53,8 @@ enum ValType {
|
|||||||
Float64,
|
Float64,
|
||||||
Char,
|
Char,
|
||||||
Record(Vec<ValType>),
|
Record(Vec<ValType>),
|
||||||
// FIXME(WebAssembly/component-model#75) are zero-sized flags allowed?
|
// Up to 65 flags to exercise up to 3 u32 values
|
||||||
//
|
Flags(UsizeInRange<0, 65>),
|
||||||
// ... otherwise go up to 65 flags to exercise up to 3 u32 values
|
|
||||||
Flags(UsizeInRange<1, 65>),
|
|
||||||
Tuple(Vec<ValType>),
|
Tuple(Vec<ValType>),
|
||||||
Variant(NonZeroLenVec<ValType>),
|
Variant(NonZeroLenVec<ValType>),
|
||||||
Union(NonZeroLenVec<ValType>),
|
Union(NonZeroLenVec<ValType>),
|
||||||
|
|||||||
@@ -700,6 +700,7 @@ impl Compiler<'_, '_> {
|
|||||||
assert_eq!(src_ty.names, dst_ty.names);
|
assert_eq!(src_ty.names, dst_ty.names);
|
||||||
let cnt = src_ty.names.len();
|
let cnt = src_ty.names.len();
|
||||||
match FlagsSize::from_count(cnt) {
|
match FlagsSize::from_count(cnt) {
|
||||||
|
FlagsSize::Size0 => {}
|
||||||
FlagsSize::Size1 => {
|
FlagsSize::Size1 => {
|
||||||
let mask = if cnt == 8 { 0xff } else { (1 << cnt) - 1 };
|
let mask = if cnt == 8 { 0xff } else { (1 << cnt) - 1 };
|
||||||
self.convert_u8_mask(src, dst, mask);
|
self.convert_u8_mask(src, dst, mask);
|
||||||
|
|||||||
@@ -606,6 +606,10 @@ impl Type {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Type::Flags(handle) => match FlagsSize::from_count(handle.names().len()) {
|
Type::Flags(handle) => match FlagsSize::from_count(handle.names().len()) {
|
||||||
|
FlagsSize::Size0 => SizeAndAlignment {
|
||||||
|
size: 0,
|
||||||
|
alignment: 1,
|
||||||
|
},
|
||||||
FlagsSize::Size1 => SizeAndAlignment {
|
FlagsSize::Size1 => SizeAndAlignment {
|
||||||
size: 1,
|
size: 1,
|
||||||
alignment: 1,
|
alignment: 1,
|
||||||
|
|||||||
@@ -699,6 +699,7 @@ impl Val {
|
|||||||
ty: handle.clone(),
|
ty: handle.clone(),
|
||||||
count: u32::try_from(handle.names().len())?,
|
count: u32::try_from(handle.names().len())?,
|
||||||
value: match FlagsSize::from_count(handle.names().len()) {
|
value: match FlagsSize::from_count(handle.names().len()) {
|
||||||
|
FlagsSize::Size0 => Box::new([]),
|
||||||
FlagsSize::Size1 => iter::once(u8::load(mem, bytes)? as u32).collect(),
|
FlagsSize::Size1 => iter::once(u8::load(mem, bytes)? as u32).collect(),
|
||||||
FlagsSize::Size2 => iter::once(u16::load(mem, bytes)? as u32).collect(),
|
FlagsSize::Size2 => iter::once(u16::load(mem, bytes)? as u32).collect(),
|
||||||
FlagsSize::Size4Plus(n) => (0..n)
|
FlagsSize::Size4Plus(n) => (0..n)
|
||||||
@@ -850,6 +851,7 @@ impl Val {
|
|||||||
|
|
||||||
Val::Flags(Flags { count, value, .. }) => {
|
Val::Flags(Flags { count, value, .. }) => {
|
||||||
match FlagsSize::from_count(*count as usize) {
|
match FlagsSize::from_count(*count as usize) {
|
||||||
|
FlagsSize::Size0 => {}
|
||||||
FlagsSize::Size1 => u8::try_from(value[0]).unwrap().store(mem, offset)?,
|
FlagsSize::Size1 => u8::try_from(value[0]).unwrap().store(mem, offset)?,
|
||||||
FlagsSize::Size2 => u16::try_from(value[0]).unwrap().store(mem, offset)?,
|
FlagsSize::Size2 => u16::try_from(value[0]).unwrap().store(mem, offset)?,
|
||||||
FlagsSize::Size4Plus(_) => {
|
FlagsSize::Size4Plus(_) => {
|
||||||
@@ -1018,6 +1020,7 @@ fn lower_list<T>(
|
|||||||
/// Note that this will always return at least 1, even if the `count` parameter is zero.
|
/// Note that this will always return at least 1, even if the `count` parameter is zero.
|
||||||
pub(crate) fn u32_count_for_flag_count(count: usize) -> usize {
|
pub(crate) fn u32_count_for_flag_count(count: usize) -> usize {
|
||||||
match FlagsSize::from_count(count) {
|
match FlagsSize::from_count(count) {
|
||||||
|
FlagsSize::Size0 => 0,
|
||||||
FlagsSize::Size1 | FlagsSize::Size2 => 1,
|
FlagsSize::Size1 | FlagsSize::Size2 => 1,
|
||||||
FlagsSize::Size4Plus(n) => n,
|
FlagsSize::Size4Plus(n) => n,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -210,7 +210,9 @@ fn make_echo_component(type_definition: &str, type_size: u32) -> String {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn make_echo_component_with_params(type_definition: &str, params: &[Param]) -> String {
|
fn make_echo_component_with_params(type_definition: &str, params: &[Param]) -> String {
|
||||||
let func = if params.len() == 1 || params.len() > 16 {
|
let func = if params.len() == 0 {
|
||||||
|
format!("(func (export \"echo\"))")
|
||||||
|
} else if params.len() == 1 || params.len() > 16 {
|
||||||
let primitive = if params.len() == 1 {
|
let primitive = if params.len() == 1 {
|
||||||
params[0].0.primitive()
|
params[0].0.primitive()
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -426,6 +426,22 @@ fn enum_derive() -> Result<()> {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn flags() -> Result<()> {
|
fn flags() -> Result<()> {
|
||||||
|
let engine = super::engine();
|
||||||
|
let mut store = Store::new(&engine, ());
|
||||||
|
|
||||||
|
// Edge case of 0 flags
|
||||||
|
wasmtime::component::flags! {
|
||||||
|
Flags0 {}
|
||||||
|
}
|
||||||
|
assert_eq!(Flags0::default(), Flags0::default());
|
||||||
|
|
||||||
|
let component = Component::new(&engine, make_echo_component(r#"(flags)"#, 0))?;
|
||||||
|
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
|
||||||
|
let func = instance.get_typed_func::<(Flags0,), Flags0, _>(&mut store, "echo")?;
|
||||||
|
let output = func.call_and_post_return(&mut store, (Flags0::default(),))?;
|
||||||
|
assert_eq!(output, Flags0::default());
|
||||||
|
|
||||||
|
// Simple 8-bit flags
|
||||||
wasmtime::component::flags! {
|
wasmtime::component::flags! {
|
||||||
Foo {
|
Foo {
|
||||||
#[component(name = "foo-bar-baz")]
|
#[component(name = "foo-bar-baz")]
|
||||||
@@ -442,9 +458,6 @@ fn flags() -> Result<()> {
|
|||||||
assert_eq!(Foo::default(), Foo::A ^ Foo::A);
|
assert_eq!(Foo::default(), Foo::A ^ Foo::A);
|
||||||
assert_eq!(Foo::B | Foo::C, !Foo::A);
|
assert_eq!(Foo::B | Foo::C, !Foo::A);
|
||||||
|
|
||||||
let engine = super::engine();
|
|
||||||
let mut store = Store::new(&engine, ());
|
|
||||||
|
|
||||||
// Happy path: component type matches flag count and names
|
// Happy path: component type matches flag count and names
|
||||||
|
|
||||||
let component = Component::new(
|
let component = Component::new(
|
||||||
|
|||||||
@@ -1228,6 +1228,7 @@
|
|||||||
|
|
||||||
;; test that flags get their upper bits all masked off
|
;; test that flags get their upper bits all masked off
|
||||||
(component
|
(component
|
||||||
|
(type $f0 (flags))
|
||||||
(type $f1 (flags "f1"))
|
(type $f1 (flags "f1"))
|
||||||
(type $f8 (flags "f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8"))
|
(type $f8 (flags "f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8"))
|
||||||
(type $f9 (flags "f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8" "f9"))
|
(type $f9 (flags "f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8" "f9"))
|
||||||
@@ -1277,6 +1278,7 @@
|
|||||||
|
|
||||||
(component $c1
|
(component $c1
|
||||||
(core module $m
|
(core module $m
|
||||||
|
(func (export "f0"))
|
||||||
(func (export "f1") (param i32)
|
(func (export "f1") (param i32)
|
||||||
(if (i32.ne (local.get 0) (i32.const 0x1)) (unreachable))
|
(if (i32.ne (local.get 0) (i32.const 0x1)) (unreachable))
|
||||||
)
|
)
|
||||||
@@ -1310,6 +1312,7 @@
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
(core instance $m (instantiate $m))
|
(core instance $m (instantiate $m))
|
||||||
|
(func (export "f0") (param $f0) (canon lift (core func $m "f0")))
|
||||||
(func (export "f1") (param $f1) (canon lift (core func $m "f1")))
|
(func (export "f1") (param $f1) (canon lift (core func $m "f1")))
|
||||||
(func (export "f8") (param $f8) (canon lift (core func $m "f8")))
|
(func (export "f8") (param $f8) (canon lift (core func $m "f8")))
|
||||||
(func (export "f9") (param $f9) (canon lift (core func $m "f9")))
|
(func (export "f9") (param $f9) (canon lift (core func $m "f9")))
|
||||||
@@ -1324,6 +1327,7 @@
|
|||||||
|
|
||||||
(component $c2
|
(component $c2
|
||||||
(import "" (instance $i
|
(import "" (instance $i
|
||||||
|
(export "f0" (func (param $f0)))
|
||||||
(export "f1" (func (param $f1)))
|
(export "f1" (func (param $f1)))
|
||||||
(export "f8" (func (param $f8)))
|
(export "f8" (func (param $f8)))
|
||||||
(export "f9" (func (param $f9)))
|
(export "f9" (func (param $f9)))
|
||||||
@@ -1334,6 +1338,7 @@
|
|||||||
(export "f64" (func (param $f64)))
|
(export "f64" (func (param $f64)))
|
||||||
(export "f65" (func (param $f65)))
|
(export "f65" (func (param $f65)))
|
||||||
))
|
))
|
||||||
|
(core func $f0 (canon lower (func $i "f0")))
|
||||||
(core func $f1 (canon lower (func $i "f1")))
|
(core func $f1 (canon lower (func $i "f1")))
|
||||||
(core func $f8 (canon lower (func $i "f8")))
|
(core func $f8 (canon lower (func $i "f8")))
|
||||||
(core func $f9 (canon lower (func $i "f9")))
|
(core func $f9 (canon lower (func $i "f9")))
|
||||||
@@ -1345,6 +1350,7 @@
|
|||||||
(core func $f65 (canon lower (func $i "f65")))
|
(core func $f65 (canon lower (func $i "f65")))
|
||||||
|
|
||||||
(core module $m
|
(core module $m
|
||||||
|
(import "" "f0" (func $f0))
|
||||||
(import "" "f1" (func $f1 (param i32)))
|
(import "" "f1" (func $f1 (param i32)))
|
||||||
(import "" "f8" (func $f8 (param i32)))
|
(import "" "f8" (func $f8 (param i32)))
|
||||||
(import "" "f9" (func $f9 (param i32)))
|
(import "" "f9" (func $f9 (param i32)))
|
||||||
@@ -1356,6 +1362,7 @@
|
|||||||
(import "" "f65" (func $f65 (param i32 i32 i32)))
|
(import "" "f65" (func $f65 (param i32 i32 i32)))
|
||||||
|
|
||||||
(func $start
|
(func $start
|
||||||
|
(call $f0)
|
||||||
(call $f1 (i32.const 0xffffff01))
|
(call $f1 (i32.const 0xffffff01))
|
||||||
(call $f8 (i32.const 0xffffff11))
|
(call $f8 (i32.const 0xffffff11))
|
||||||
(call $f9 (i32.const 0xffffff11))
|
(call $f9 (i32.const 0xffffff11))
|
||||||
@@ -1371,6 +1378,7 @@
|
|||||||
)
|
)
|
||||||
(core instance $m (instantiate $m
|
(core instance $m (instantiate $m
|
||||||
(with "" (instance
|
(with "" (instance
|
||||||
|
(export "f0" (func $f0))
|
||||||
(export "f1" (func $f1))
|
(export "f1" (func $f1))
|
||||||
(export "f8" (func $f8))
|
(export "f8" (func $f8))
|
||||||
(export "f9" (func $f9))
|
(export "f9" (func $f9))
|
||||||
|
|||||||
Reference in New Issue
Block a user