diff --git a/cranelift/codegen/src/ir/dfg.rs b/cranelift/codegen/src/ir/dfg.rs index 9e8634b741..7d453836f5 100644 --- a/cranelift/codegen/src/ir/dfg.rs +++ b/cranelift/codegen/src/ir/dfg.rs @@ -14,6 +14,7 @@ use crate::isa::TargetIsa; use crate::packed_option::ReservedValue; use crate::write::write_operands; use crate::HashMap; +use alloc::vec::Vec; use core::fmt; use core::iter; use core::mem; @@ -776,6 +777,14 @@ impl DataFlowGraph { self.ebbs[ebb].params.as_slice(&self.value_lists) } + /// Get the types of the parameters on `ebb`. + pub fn ebb_param_types(&self, ebb: Ebb) -> Vec { + self.ebb_params(ebb) + .iter() + .map(|&v| self.value_type(v)) + .collect() + } + /// Append a parameter with type `ty` to `ebb`. pub fn append_ebb_param(&mut self, ebb: Ebb, ty: Type) -> Value { let param = self.values.next_key(); diff --git a/cranelift/codegen/src/ir/extfunc.rs b/cranelift/codegen/src/ir/extfunc.rs index 9274efe9b9..ab79bb5ea0 100644 --- a/cranelift/codegen/src/ir/extfunc.rs +++ b/cranelift/codegen/src/ir/extfunc.rs @@ -88,6 +88,16 @@ impl Signature { .count() } + /// Count the number of normal parameters in a signature. + /// Exclude special-purpose parameters that represent runtime stuff and not WebAssembly + /// arguments. + pub fn num_normal_params(&self) -> usize { + self.params + .iter() + .filter(|arg| arg.purpose == ArgumentPurpose::Normal) + .count() + } + /// Does this signature take an struct return pointer parameter? pub fn uses_struct_return_param(&self) -> bool { self.uses_special_param(ArgumentPurpose::StructReturn) @@ -102,6 +112,24 @@ impl Signature { .count() > 1 } + + /// Collect the normal parameter types of the signature; see `[ArgumentPurpose::Normal]`. + pub fn param_types(&self) -> Vec { + self.params + .iter() + .filter(|ap| ap.purpose == ArgumentPurpose::Normal) + .map(|ap| ap.value_type) + .collect() + } + + /// Collect the normal return types of the signature; see `[ArgumentPurpose::Normal]`. + pub fn return_types(&self) -> Vec { + self.returns + .iter() + .filter(|ap| ap.purpose == ArgumentPurpose::Normal) + .map(|ap| ap.value_type) + .collect() + } } /// Wrapper type capable of displaying a `Signature` with correct register names. diff --git a/cranelift/wasm/src/code_translator.rs b/cranelift/wasm/src/code_translator.rs index e155594093..68f9cac1ba 100644 --- a/cranelift/wasm/src/code_translator.rs +++ b/cranelift/wasm/src/code_translator.rs @@ -267,12 +267,14 @@ pub fn translate_operator( } Operator::End => { let frame = state.control_stack.pop().unwrap(); + let next_ebb = frame.following_code(); if !builder.is_unreachable() || !builder.is_pristine() { let return_count = frame.num_return_values(); - builder - .ins() - .jump(frame.following_code(), state.peekn(return_count)); + let return_args = state.peekn_mut(return_count); + let next_ebb_types = builder.func.dfg.ebb_param_types(next_ebb); + bitcast_arguments(return_args, &next_ebb_types, builder); + builder.ins().jump(frame.following_code(), return_args); // You might expect that if we just finished an `if` block that // didn't have a corresponding `else` block, then we would clean // up our duplicate set of parameters that we pushed earlier @@ -280,16 +282,14 @@ pub fn translate_operator( // since we truncate the stack back to the original height // below. } - builder.switch_to_block(frame.following_code()); - builder.seal_block(frame.following_code()); + builder.switch_to_block(next_ebb); + builder.seal_block(next_ebb); // If it is a loop we also have to seal the body loop block if let ControlStackFrame::Loop { header, .. } = frame { builder.seal_block(header) } state.stack.truncate(frame.original_stack_size()); - state - .stack - .extend_from_slice(builder.ebb_params(frame.following_code())); + state.stack.extend_from_slice(builder.ebb_params(next_ebb)); } /**************************** Branch instructions ********************************* * The branch instructions all have as arguments a target nesting level, which @@ -325,9 +325,17 @@ pub fn translate_operator( }; (return_count, frame.br_destination()) }; - builder - .ins() - .jump(br_destination, state.peekn(return_count)); + + // Bitcast any vector arguments to their default type, I8X16, before jumping. + let destination_args = state.peekn_mut(return_count); + let destination_types = builder.func.dfg.ebb_param_types(br_destination); + bitcast_arguments( + destination_args, + &destination_types[..return_count], + builder, + ); + + builder.ins().jump(br_destination, destination_args); state.popn(return_count); state.reachable = false; } @@ -406,7 +414,17 @@ pub fn translate_operator( frame.set_branched_to_exit(); frame.br_destination() }; - builder.ins().jump(real_dest_ebb, state.peekn(return_count)); + + // Bitcast any vector arguments to their default type, I8X16, before jumping. + let destination_args = state.peekn_mut(return_count); + let destination_types = builder.func.dfg.ebb_param_types(real_dest_ebb); + bitcast_arguments( + destination_args, + &destination_types[..return_count], + builder, + ); + + builder.ins().jump(real_dest_ebb, destination_args); } state.popn(return_count); } @@ -420,10 +438,14 @@ pub fn translate_operator( (return_count, frame.br_destination()) }; { - let args = state.peekn(return_count); + let return_args = state.peekn_mut(return_count); + let return_types = &builder.func.signature.return_types(); + bitcast_arguments(return_args, &return_types, builder); match environ.return_mode() { - ReturnMode::NormalReturns => builder.ins().return_(args), - ReturnMode::FallthroughReturn => builder.ins().jump(br_destination, args), + ReturnMode::NormalReturns => builder.ins().return_(return_args), + ReturnMode::FallthroughReturn => { + builder.ins().jump(br_destination, return_args) + } }; } state.popn(return_count); @@ -436,11 +458,18 @@ pub fn translate_operator( ************************************************************************************/ Operator::Call { function_index } => { let (fref, num_args) = state.get_direct_func(builder.func, *function_index, environ)?; + + // Bitcast any vector arguments to their default type, I8X16, before calling. + let callee_signature = + &builder.func.dfg.signatures[builder.func.dfg.ext_funcs[fref].signature]; + let args = state.peekn_mut(num_args); + bitcast_arguments(args, &callee_signature.param_types(), builder); + let call = environ.translate_call( builder.cursor(), FuncIndex::from_u32(*function_index), fref, - state.peekn(num_args), + args, )?; let inst_results = builder.inst_results(call); debug_assert_eq!( @@ -459,6 +488,12 @@ pub fn translate_operator( let (sigref, num_args) = state.get_indirect_sig(builder.func, *index, environ)?; let table = state.get_table(builder.func, *table_index, environ)?; let callee = state.pop1(); + + // Bitcast any vector arguments to their default type, I8X16, before calling. + let callee_signature = &builder.func.dfg.signatures[sigref]; + let args = state.peekn_mut(num_args); + bitcast_arguments(args, &callee_signature.param_types(), builder); + let call = environ.translate_call_indirect( builder.cursor(), TableIndex::from_u32(*table_index), @@ -1635,6 +1670,11 @@ fn translate_br_if( ) { let val = state.pop1(); let (br_destination, inputs) = translate_br_if_args(relative_depth, state); + + // Bitcast any vector arguments to their default type, I8X16, before jumping. + let destination_types = builder.func.dfg.ebb_param_types(br_destination); + bitcast_arguments(inputs, &destination_types[..inputs.len()], builder); + builder.ins().brnz(val, br_destination, inputs); #[cfg(feature = "basic-blocks")] @@ -1649,7 +1689,7 @@ fn translate_br_if( fn translate_br_if_args( relative_depth: u32, state: &mut FuncTranslationState, -) -> (ir::Ebb, &[ir::Value]) { +) -> (ir::Ebb, &mut [ir::Value]) { let i = state.control_stack.len() - 1 - (relative_depth as usize); let (return_count, br_destination) = { let frame = &mut state.control_stack[i]; @@ -1663,7 +1703,7 @@ fn translate_br_if_args( }; (return_count, frame.br_destination()) }; - let inputs = state.peekn(return_count); + let inputs = state.peekn_mut(return_count); (br_destination, inputs) } @@ -1826,7 +1866,7 @@ fn type_of(operator: &Operator) -> Type { /// Some SIMD operations only operate on I8X16 in CLIF; this will convert them to that type by /// adding a raw_bitcast if necessary. -fn optionally_bitcast_vector( +pub fn optionally_bitcast_vector( value: Value, needed_type: Type, builder: &mut FunctionBuilder, @@ -1862,3 +1902,28 @@ fn pop2_with_bitcast( let bitcast_b = optionally_bitcast_vector(b, needed_type, builder); (bitcast_a, bitcast_b) } + +/// A helper for bitcasting a sequence of values (e.g. function arguments). If a value is a +/// vector type that does not match its expected type, this will modify the value in place to point +/// to the result of a `raw_bitcast`. This conversion is necessary to translate Wasm code that +/// uses `V128` as function parameters (or implicitly in EBB parameters) and still use specific +/// CLIF types (e.g. `I32X4`) in the function body. +pub fn bitcast_arguments( + arguments: &mut [Value], + expected_types: &[Type], + builder: &mut FunctionBuilder, +) { + assert_eq!(arguments.len(), expected_types.len()); + for (i, t) in expected_types.iter().enumerate() { + if t.is_vector() { + assert!( + builder.func.dfg.value_type(arguments[i]).is_vector(), + "unexpected type mismatch: expected {}, argument {} was actually of type {}", + t, + arguments[i], + builder.func.dfg.value_type(arguments[i]) + ); + arguments[i] = optionally_bitcast_vector(arguments[i], *t, builder) + } + } +} diff --git a/cranelift/wasm/src/func_translator.rs b/cranelift/wasm/src/func_translator.rs index 8f5d49bcdf..b6a46ecb01 100644 --- a/cranelift/wasm/src/func_translator.rs +++ b/cranelift/wasm/src/func_translator.rs @@ -4,7 +4,7 @@ //! function to Cranelift IR guided by a `FuncEnvironment` which provides information about the //! WebAssembly module and the runtime environment. -use crate::code_translator::translate_operator; +use crate::code_translator::{bitcast_arguments, translate_operator}; use crate::environ::{FuncEnvironment, ReturnMode, WasmResult}; use crate::state::{FuncTranslationState, ModuleTranslationState}; use crate::translation_utils::get_vmctx_value_label; @@ -240,7 +240,11 @@ fn parse_function_body( debug_assert!(builder.is_pristine()); if !builder.is_unreachable() { match environ.return_mode() { - ReturnMode::NormalReturns => builder.ins().return_(&state.stack), + ReturnMode::NormalReturns => { + let return_types = &builder.func.signature.return_types(); + bitcast_arguments(&mut state.stack, &return_types, builder); + builder.ins().return_(&state.stack) + } ReturnMode::FallthroughReturn => builder.ins().fallthrough_return(&state.stack), }; } diff --git a/cranelift/wasm/src/state/func_state.rs b/cranelift/wasm/src/state/func_state.rs index c768248ddd..98f75e6bc9 100644 --- a/cranelift/wasm/src/state/func_state.rs +++ b/cranelift/wasm/src/state/func_state.rs @@ -306,31 +306,40 @@ impl FuncTranslationState { (v1, v2, v3) } + /// Helper to ensure the the stack size is at least as big as `n`; note that due to + /// `debug_assert` this will not execute in non-optimized builds. + #[inline] + fn ensure_length_is_at_least(&self, n: usize) { + debug_assert!( + n <= self.stack.len(), + "attempted to access {} values but stack only has {} values", + n, + self.stack.len() + ) + } + /// Pop the top `n` values on the stack. /// /// The popped values are not returned. Use `peekn` to look at them before popping. pub(crate) fn popn(&mut self, n: usize) { - debug_assert!( - n <= self.stack.len(), - "popn({}) but stack only has {} values", - n, - self.stack.len() - ); + self.ensure_length_is_at_least(n); let new_len = self.stack.len() - n; self.stack.truncate(new_len); } /// Peek at the top `n` values on the stack in the order they were pushed. pub(crate) fn peekn(&self, n: usize) -> &[Value] { - debug_assert!( - n <= self.stack.len(), - "peekn({}) but stack only has {} values", - n, - self.stack.len() - ); + self.ensure_length_is_at_least(n); &self.stack[self.stack.len() - n..] } + /// Peek at the top `n` values on the stack in the order they were pushed. + pub(crate) fn peekn_mut(&mut self, n: usize) -> &mut [Value] { + self.ensure_length_is_at_least(n); + let len = self.stack.len(); + &mut self.stack[len - n..] + } + /// Push a block on the control stack. pub(crate) fn push_block( &mut self, @@ -465,7 +474,7 @@ impl FuncTranslationState { Occupied(entry) => Ok(*entry.get()), Vacant(entry) => { let sig = environ.make_indirect_sig(func, index)?; - Ok(*entry.insert((sig, normal_args(&func.dfg.signatures[sig])))) + Ok(*entry.insert((sig, func.dfg.signatures[sig].num_normal_params()))) } } } @@ -486,17 +495,8 @@ impl FuncTranslationState { Vacant(entry) => { let fref = environ.make_direct_func(func, index)?; let sig = func.dfg.ext_funcs[fref].signature; - Ok(*entry.insert((fref, normal_args(&func.dfg.signatures[sig])))) + Ok(*entry.insert((fref, func.dfg.signatures[sig].num_normal_params()))) } } } } - -/// Count the number of normal parameters in a signature. -/// Exclude special-purpose parameters that represent runtime stuff and not WebAssembly arguments. -fn normal_args(sig: &ir::Signature) -> usize { - sig.params - .iter() - .filter(|arg| arg.purpose == ir::ArgumentPurpose::Normal) - .count() -} diff --git a/cranelift/wasmtests/call-simd.wat b/cranelift/wasmtests/call-simd.wat new file mode 100644 index 0000000000..61834d86bd --- /dev/null +++ b/cranelift/wasmtests/call-simd.wat @@ -0,0 +1,14 @@ +(module + (func $main + (v128.const i32x4 1 2 3 4) + (v128.const i32x4 1 2 3 4) + (call $add) + drop + ) + (func $add (param $a v128) (param $b v128) (result v128) + (local.get $a) + (local.get $b) + (i32x4.add) + ) + (start $main) +) diff --git a/cranelift/wasmtests/icall-simd.wat b/cranelift/wasmtests/icall-simd.wat new file mode 100644 index 0000000000..d656b265b9 --- /dev/null +++ b/cranelift/wasmtests/icall-simd.wat @@ -0,0 +1,7 @@ +(module + (type $ft (func (param v128) (result v128))) + (func $foo (export "foo") (param i32) (param v128) (result v128) + (call_indirect (type $ft) (local.get 1) (local.get 0)) + ) + (table (;0;) 23 23 anyfunc) +)