Cranelift: Introduce the return_call and return_call_indirect instructions (#5679)

* Cranelift: Introduce the `tail` calling convention

This is an unstable-ABI calling convention that we will eventually use to
support Wasm tail calls.

Co-Authored-By: Jamey Sharp <jsharp@fastly.com>

* Cranelift: Introduce the `return_call` and `return_call_indirect` instructions

These will be used to implement tail calls for Wasm and any other language
targeting CLIF. The `return_call_indirect` instruction differs from the Wasm
instruction of the same name by taking a native address callee rather than a
Wasm function index.

Co-Authored-By: Jamey Sharp <jsharp@fastly.com>

* Cranelift: Implement verification rules for `return_call[_indirect]`

They must:

* have the same return types between the caller and callee,
* have the same calling convention between caller and callee,
* and use a calling convention that supports tail calls.

Co-Authored-By: Jamey Sharp <jsharp@fastly.com>

* cargo fmt

---------

Co-authored-by: Jamey Sharp <jsharp@fastly.com>
This commit is contained in:
Nick Fitzgerald
2023-02-01 13:20:35 -08:00
committed by GitHub
parent ffbbfbffce
commit bdfb746548
8 changed files with 298 additions and 101 deletions

View File

@@ -253,6 +253,51 @@ fn define_control_flow(
.call(), .call(),
); );
ig.push(
Inst::new(
"return_call",
r#"
Direct tail call.
Tail call a function which has been declared in the preamble. The
argument types must match the function's signature, the caller and
callee calling conventions must be the same, and must be a calling
convention that supports tail calls.
This instruction is a block terminator.
"#,
&formats.call,
)
.operands_in(vec![FN, args])
.returns()
.call(),
);
// Indirect tail call: described with the `call_indirect` instruction format
// and taking the signature `SIG`, the `callee` address, and the argument
// list `args`.
ig.push(
    Inst::new(
        "return_call_indirect",
        r#"
Indirect tail call.
Call the function pointed to by `callee` with the given arguments. The
argument types must match the function's signature, the caller and
callee calling conventions must be the same, and must be a calling
convention that supports tail calls.
This instruction is a block terminator.
Note that this is different from WebAssembly's `return_call_indirect`;
the callee is a native address, rather than a table index. For
WebAssembly, `table_addr` and `load` are used to obtain a native address
from a table.
"#,
        &formats.call_indirect,
    )
    .operands_in(vec![SIG, callee, args])
    .returns()
    .call(),
);
let FN = &Operand::new("FN", &entities.func_ref) let FN = &Operand::new("FN", &entities.func_ref)
.with_doc("function to call, declared by `function`"); .with_doc("function to call, declared by `function`");
let addr = &Operand::new("addr", iAddr); let addr = &Operand::new("addr", iAddr);

View File

@@ -22,6 +22,7 @@ use core::u16;
use alloc::collections::BTreeMap; use alloc::collections::BTreeMap;
#[cfg(feature = "enable-serde")] #[cfg(feature = "enable-serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
/// Storage for instructions within the DFG. /// Storage for instructions within the DFG.
#[derive(Clone, PartialEq, Hash)] #[derive(Clone, PartialEq, Hash)]
@@ -797,43 +798,22 @@ impl DataFlowGraph {
where where
I: Iterator<Item = Option<Value>>, I: Iterator<Item = Option<Value>>,
{ {
let mut reuse = reuse.fuse();
self.results[inst].clear(&mut self.value_lists); self.results[inst].clear(&mut self.value_lists);
// Get the call signature if this is a function call. let mut reuse = reuse.fuse();
if let Some(sig) = self.call_signature(inst) { let result_tys: SmallVec<[_; 16]> = self.inst_result_types(inst, ctrl_typevar).collect();
// Create result values corresponding to the call return types. let num_results = result_tys.len();
debug_assert_eq!(
self.insts[inst].opcode().constraints().num_fixed_results(), for ty in result_tys {
0 if let Some(Some(v)) = reuse.next() {
); debug_assert_eq!(self.value_type(v), ty, "Reused {} is wrong type", ty);
let num_results = self.signatures[sig].returns.len(); self.attach_result(inst, v);
for res_idx in 0..num_results { } else {
let ty = self.signatures[sig].returns[res_idx].value_type; self.append_result(inst, ty);
if let Some(Some(v)) = reuse.next() {
debug_assert_eq!(self.value_type(v), ty, "Reused {} is wrong type", ty);
self.attach_result(inst, v);
} else {
self.append_result(inst, ty);
}
} }
num_results
} else {
// Create result values corresponding to the opcode's constraints.
let constraints = self.insts[inst].opcode().constraints();
let num_results = constraints.num_fixed_results();
for res_idx in 0..num_results {
let ty = constraints.result_type(res_idx, ctrl_typevar);
if let Some(Some(v)) = reuse.next() {
debug_assert_eq!(self.value_type(v), ty, "Reused {} is wrong type", ty);
self.attach_result(inst, v);
} else {
self.append_result(inst, ty);
}
}
num_results
} }
num_results
} }
/// Create a `ReplaceBuilder` that will replace `inst` with a new instruction in place. /// Create a `ReplaceBuilder` that will replace `inst` with a new instruction in place.
@@ -977,6 +957,84 @@ impl DataFlowGraph {
} }
} }
/// Call signature of `inst`, with tail calls filtered out.
///
/// Yields `None` when `call_signature` finds no signature for `inst`, and also
/// when the opcode of `inst` is `ReturnCall` or `ReturnCallIndirect`.
fn non_tail_call_signature(&self, inst: Inst) -> Option<SigRef> {
    self.call_signature(inst).filter(|_| {
        !matches!(
            self.insts[inst].opcode(),
            ir::Opcode::ReturnCall | ir::Opcode::ReturnCallIndirect
        )
    })
}
// Verifier-only helper; all other callers should simply use
// `dfg.inst_results(inst).len()`.
//
// For a non-tail call this is the number of return values in the callee's
// signature; for every other instruction (tail calls included) it is the
// opcode's number of fixed results.
pub(crate) fn num_expected_results_for_verifier(&self, inst: Inst) -> usize {
    if let Some(sig) = self.non_tail_call_signature(inst) {
        return self.signatures[sig].returns.len();
    }
    self.insts[inst].opcode().constraints().num_fixed_results()
}
/// Iterate over the types of the results that `inst` produces.
///
/// For a non-tail call the types are read from the return list of the call's
/// signature. For any other instruction (tail calls included) they are derived
/// from the opcode's constraints and `ctrl_typevar`.
pub fn inst_result_types<'a>(
    &'a self,
    inst: Inst,
    ctrl_typevar: Type,
) -> impl iter::ExactSizeIterator<Item = Type> + 'a {
    /// The two sources of result types, each carrying the index of the next
    /// result to yield.
    enum ResultTypes<'a> {
        FromSignature {
            dfg: &'a DataFlowGraph,
            sig: SigRef,
            idx: usize,
        },
        FromConstraints {
            constraints: ir::instructions::OpcodeConstraints,
            ctrl_typevar: Type,
            idx: usize,
        },
    }

    impl Iterator for ResultTypes<'_> {
        type Item = Type;

        fn next(&mut self) -> Option<Type> {
            match self {
                Self::FromSignature { dfg, sig, idx } => {
                    // `get` returns `None` once the return list is exhausted.
                    let ty = dfg.signatures[*sig].returns.get(*idx)?.value_type;
                    *idx += 1;
                    Some(ty)
                }
                Self::FromConstraints {
                    constraints,
                    ctrl_typevar,
                    idx,
                } => {
                    if *idx >= constraints.num_fixed_results() {
                        return None;
                    }
                    let ty = constraints.result_type(*idx, *ctrl_typevar);
                    *idx += 1;
                    Some(ty)
                }
            }
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            // Exact remaining count, as `ExactSizeIterator` requires.
            let remaining = match self {
                Self::FromSignature { dfg, sig, idx } => {
                    dfg.signatures[*sig].returns.len() - *idx
                }
                Self::FromConstraints {
                    constraints, idx, ..
                } => constraints.num_fixed_results() - *idx,
            };
            (remaining, Some(remaining))
        }
    }

    impl ExactSizeIterator for ResultTypes<'_> {}

    if let Some(sig) = self.non_tail_call_signature(inst) {
        ResultTypes::FromSignature {
            dfg: self,
            sig,
            idx: 0,
        }
    } else {
        ResultTypes::FromConstraints {
            constraints: self.insts[inst].opcode().constraints(),
            ctrl_typevar,
            idx: 0,
        }
    }
}
/// Check if `inst` is a branch. /// Check if `inst` is a branch.
pub fn analyze_branch(&self, inst: Inst) -> BranchInfo { pub fn analyze_branch(&self, inst: Inst) -> BranchInfo {
self.insts[inst].analyze_branch() self.insts[inst].analyze_branch()
@@ -995,20 +1053,7 @@ impl DataFlowGraph {
result_idx: usize, result_idx: usize,
ctrl_typevar: Type, ctrl_typevar: Type,
) -> Option<Type> { ) -> Option<Type> {
let constraints = self.insts[inst].opcode().constraints(); self.inst_result_types(inst, ctrl_typevar).nth(result_idx)
let num_fixed_results = constraints.num_fixed_results();
if result_idx < num_fixed_results {
return Some(constraints.result_type(result_idx, ctrl_typevar));
}
// Not a fixed result, try to extract a return type from the call signature.
self.call_signature(inst).and_then(|sigref| {
self.signatures[sigref]
.returns
.get(result_idx - num_fixed_results)
.map(|&arg| arg.value_type)
})
} }
/// Get the controlling type variable, or `INVALID` if `inst` isn't polymorphic. /// Get the controlling type variable, or `INVALID` if `inst` isn't polymorphic.
@@ -1283,29 +1328,15 @@ impl DataFlowGraph {
ctrl_typevar: Type, ctrl_typevar: Type,
reuse: &[Value], reuse: &[Value],
) -> usize { ) -> usize {
// Get the call signature if this is a function call. let mut reuse_iter = reuse.iter().copied();
if let Some(sig) = self.call_signature(inst) { let result_tys: SmallVec<[_; 16]> = self.inst_result_types(inst, ctrl_typevar).collect();
assert_eq!( for ty in result_tys {
self.insts[inst].opcode().constraints().num_fixed_results(), if ty.is_dynamic_vector() {
0 self.check_dynamic_type(ty)
); .unwrap_or_else(|| panic!("Use of undeclared dynamic type: {}", ty));
for res_idx in 0..self.signatures[sig].returns.len() {
let ty = self.signatures[sig].returns[res_idx].value_type;
if let Some(v) = reuse.get(res_idx) {
self.set_value_type_for_parser(*v, ty);
}
} }
} else { if let Some(v) = reuse_iter.next() {
let constraints = self.insts[inst].opcode().constraints(); self.set_value_type_for_parser(v, ty);
for res_idx in 0..constraints.num_fixed_results() {
let ty = constraints.result_type(res_idx, ctrl_typevar);
if ty.is_dynamic_vector() {
self.check_dynamic_type(ty)
.unwrap_or_else(|| panic!("Use of undeclared dynamic type: {}", ty));
}
if let Some(v) = reuse.get(res_idx) {
self.set_value_type_for_parser(*v, ty);
}
} }
} }

View File

@@ -14,6 +14,8 @@ pub enum CallConv {
Fast, Fast,
/// Smallest caller code size, not ABI-stable. /// Smallest caller code size, not ABI-stable.
Cold, Cold,
/// Supports tail calls, not ABI-stable.
Tail,
/// System V-style convention used on many platforms. /// System V-style convention used on many platforms.
SystemV, SystemV,
/// Windows "fastcall" convention, also used for x64 and ARM. /// Windows "fastcall" convention, also used for x64 and ARM.
@@ -64,6 +66,14 @@ impl CallConv {
} }
} }
/// Whether tail calls may be made with this calling convention.
///
/// Only `CallConv::Tail` answers `true`.
pub fn supports_tail_calls(&self) -> bool {
    matches!(self, CallConv::Tail)
}
/// Is the calling convention extending the Windows Fastcall ABI? /// Is the calling convention extending the Windows Fastcall ABI?
pub fn extends_windows_fastcall(self) -> bool { pub fn extends_windows_fastcall(self) -> bool {
match self { match self {
@@ -94,6 +104,7 @@ impl fmt::Display for CallConv {
f.write_str(match *self { f.write_str(match *self {
Self::Fast => "fast", Self::Fast => "fast",
Self::Cold => "cold", Self::Cold => "cold",
Self::Tail => "tail",
Self::SystemV => "system_v", Self::SystemV => "system_v",
Self::WindowsFastcall => "windows_fastcall", Self::WindowsFastcall => "windows_fastcall",
Self::AppleAarch64 => "apple_aarch64", Self::AppleAarch64 => "apple_aarch64",
@@ -111,6 +122,7 @@ impl str::FromStr for CallConv {
match s { match s {
"fast" => Ok(Self::Fast), "fast" => Ok(Self::Fast),
"cold" => Ok(Self::Cold), "cold" => Ok(Self::Cold),
"tail" => Ok(Self::Tail),
"system_v" => Ok(Self::SystemV), "system_v" => Ok(Self::SystemV),
"windows_fastcall" => Ok(Self::WindowsFastcall), "windows_fastcall" => Ok(Self::WindowsFastcall),
"apple_aarch64" => Ok(Self::AppleAarch64), "apple_aarch64" => Ok(Self::AppleAarch64),

View File

@@ -708,6 +708,7 @@ impl ABIMachineSpec for X64ABIMachineSpec {
regs: &[Writable<RealReg>], regs: &[Writable<RealReg>],
) -> Vec<Writable<RealReg>> { ) -> Vec<Writable<RealReg>> {
let mut regs: Vec<Writable<RealReg>> = match call_conv { let mut regs: Vec<Writable<RealReg>> = match call_conv {
CallConv::Tail => unimplemented!(),
CallConv::Fast | CallConv::Cold | CallConv::SystemV | CallConv::WasmtimeSystemV => regs CallConv::Fast | CallConv::Cold | CallConv::SystemV | CallConv::WasmtimeSystemV => regs
.iter() .iter()
.cloned() .cloned()
@@ -823,6 +824,7 @@ fn get_intreg_for_retval(
retval_idx: usize, retval_idx: usize,
) -> Option<Reg> { ) -> Option<Reg> {
match call_conv { match call_conv {
CallConv::Tail => unimplemented!(),
CallConv::Fast | CallConv::Cold | CallConv::SystemV => match intreg_idx { CallConv::Fast | CallConv::Cold | CallConv::SystemV => match intreg_idx {
0 => Some(regs::rax()), 0 => Some(regs::rax()),
1 => Some(regs::rdx()), 1 => Some(regs::rdx()),
@@ -851,6 +853,7 @@ fn get_fltreg_for_retval(
retval_idx: usize, retval_idx: usize,
) -> Option<Reg> { ) -> Option<Reg> {
match call_conv { match call_conv {
CallConv::Tail => unimplemented!(),
CallConv::Fast | CallConv::Cold | CallConv::SystemV => match fltreg_idx { CallConv::Fast | CallConv::Cold | CallConv::SystemV => match fltreg_idx {
0 => Some(regs::xmm0()), 0 => Some(regs::xmm0()),
1 => Some(regs::xmm1()), 1 => Some(regs::xmm1()),

View File

@@ -532,23 +532,15 @@ impl<'a> Verifier<'a> {
)); ));
} }
let num_fixed_results = inst_data.opcode().constraints().num_fixed_results(); let expected_num_results = dfg.num_expected_results_for_verifier(inst);
// var_results is 0 if we aren't a call instruction
let var_results = dfg
.call_signature(inst)
.map_or(0, |sig| dfg.signatures[sig].returns.len());
let total_results = num_fixed_results + var_results;
// All result values for multi-valued instructions are created // All result values for multi-valued instructions are created
let got_results = dfg.inst_results(inst).len(); let got_results = dfg.inst_results(inst).len();
if got_results != total_results { if got_results != expected_num_results {
return errors.fatal(( return errors.fatal((
inst, inst,
self.context(inst), self.context(inst),
format!( format!("expected {expected_num_results} result values, found {got_results}"),
"expected {} result values, found {}",
total_results, got_results,
),
)); ));
} }
@@ -1426,29 +1418,91 @@ impl<'a> Verifier<'a> {
} }
fn typecheck_return(&self, inst: Inst, errors: &mut VerifierErrors) -> VerifierStepResult<()> { fn typecheck_return(&self, inst: Inst, errors: &mut VerifierErrors) -> VerifierStepResult<()> {
if self.func.dfg.insts[inst].opcode().is_return() { match self.func.dfg.insts[inst] {
let args = self.func.dfg.inst_variable_args(inst); ir::InstructionData::MultiAry {
let expected_types = &self.func.signature.returns; opcode: Opcode::Return,
if args.len() != expected_types.len() { args,
return errors.nonfatal(( } => {
let types = args
.as_slice(&self.func.dfg.value_lists)
.iter()
.map(|v| self.func.dfg.value_type(*v));
self.typecheck_return_types(
inst,
types,
errors,
"arguments of return must match function signature",
)?;
}
ir::InstructionData::Call {
opcode: Opcode::ReturnCall,
func_ref,
..
} => {
let sig_ref = self.func.dfg.ext_funcs[func_ref].signature;
self.typecheck_tail_call(inst, sig_ref, errors)?;
}
ir::InstructionData::CallIndirect {
opcode: Opcode::ReturnCallIndirect,
sig_ref,
..
} => {
self.typecheck_tail_call(inst, sig_ref, errors)?;
}
inst => debug_assert!(!inst.opcode().is_return()),
}
Ok(())
}
/// Check the tail call `inst` whose callee has the signature `sig_ref`.
///
/// Reports an error if the callee's calling convention does not support tail
/// calls, and another if it differs from the calling convention of the
/// function being verified. The callee's return types are then checked
/// against the current function's signature by `typecheck_return_types`,
/// whose result is returned.
fn typecheck_tail_call(
    &self,
    inst: Inst,
    sig_ref: SigRef,
    errors: &mut VerifierErrors,
) -> VerifierStepResult<()> {
    let callee_sig = &self.func.dfg.signatures[sig_ref];
    let cc = callee_sig.call_conv;

    if !cc.supports_tail_calls() {
        errors.report((
            inst,
            self.context(inst),
            format!("calling convention `{cc}` does not support tail calls"),
        ));
    }

    let caller_cc = self.func.signature.call_conv;
    if caller_cc != cc {
        errors.report((
            inst,
            self.context(inst),
            "callee's calling convention must match caller",
        ));
    }

    // A tail call hands the callee's results straight back to our caller, so
    // they are checked exactly like the operands of a `return`.
    let callee_returns = callee_sig.returns.iter().map(|ret| ret.value_type);
    self.typecheck_return_types(
        inst,
        callee_returns,
        errors,
        "results of callee must match caller",
    )
}
fn typecheck_return_types(
&self,
inst: Inst,
actual_types: impl ExactSizeIterator<Item = Type>,
errors: &mut VerifierErrors,
message: &str,
) -> VerifierStepResult<()> {
let expected_types = &self.func.signature.returns;
if actual_types.len() != expected_types.len() {
return errors.nonfatal((inst, self.context(inst), message));
}
for (i, (actual_type, &expected_type)) in actual_types.zip(expected_types).enumerate() {
if actual_type != expected_type.value_type {
errors.report((
inst, inst,
self.context(inst), self.context(inst),
"arguments of return must match function signature", format!(
"result {i} has type {actual_type}, must match function signature of \
{expected_type}"
),
)); ));
} }
for (i, (&arg, &expected_type)) in args.iter().zip(expected_types).enumerate() {
let arg_type = self.func.dfg.value_type(arg);
if arg_type != expected_type.value_type {
errors.report((
inst,
self.context(inst),
format!(
"arg {} ({}) has type {}, must match function signature of {}",
i, arg, arg_type, expected_type
),
));
}
}
} }
Ok(()) Ok(())
} }

View File

@@ -0,0 +1,50 @@
test verifier
; Caller and callee both use `tail` and both return i32.
function %test_1(i32) -> i32 tail { ; Ok
fn0 = %wow(i32) -> i32 tail
block0(v0: i32):
return_call fn0(v0)
}
; Caller uses `fast` while the callee uses `tail`.
function %test_2(i32) -> i32 fast {
fn0 = %wow(i32) -> i32 tail
block0(v0: i32):
return_call fn0(v0) ; error: callee's calling convention must match caller
}
; Caller uses `tail` while the callee uses `fast`; two errors are expected.
function %test_3(i32) -> i32 tail {
fn0 = %wow(i32) -> i32 fast
block0(v0: i32):
return_call fn0(v0) ; error: calling convention `fast` does not support tail calls
; error: callee's calling convention must match caller
}
; Caller and callee agree on `system_v`.
function %test_4(i32) -> i32 system_v {
fn0 = %wow(i32) -> i32 system_v
block0(v0: i32):
return_call fn0(v0) ; error: calling convention `system_v` does not support tail calls
}
; Caller returns nothing while the callee returns i32.
function %test_5(i32) tail {
fn0 = %wow(i32) -> i32 tail
block0(v0: i32):
return_call fn0(v0) ; error: results of callee must match caller
}
; Caller returns i32 while the callee returns nothing.
function %test_6(i32) -> i32 tail {
fn0 = %wow(i32) tail
block0(v0: i32):
return_call fn0(v0) ; error: results of callee must match caller
}
; Same number of results, but the callee returns i64 where the caller returns i32.
function %test_7(i32) -> i32 tail {
fn0 = %wow(i32) -> i64 tail
block0(v0: i32):
return_call fn0(v0) ; error: result 0 has type i64, must match function signature of i32
}
; The call passes no arguments although fn0 is declared with one i32 parameter.
function %test_8(i32) -> i32 tail {
fn0 = %wow(i32) -> i32 tail
block0(v0: i32):
return_call fn0() ; error: mismatched argument count for `return_call fn0()`: got 0, expected 1
}

View File

@@ -19,7 +19,7 @@ function %incorrect_arg_type(i32, i8) -> i32 {
function %incorrect_return_type() -> f32 { function %incorrect_return_type() -> f32 {
block0: block0:
v0 = iconst.i32 1 v0 = iconst.i32 1
return v0 ; error: arg 0 (v0) has type i32, must match function signature of f32 return v0 ; error: result 0 has type i32, must match function signature of f32
} }
function %too_many_return_values() { function %too_many_return_values() {
@@ -82,7 +82,7 @@ function %jump_args() {
v0 = iconst.i16 10 v0 = iconst.i16 10
v3 = iconst.i64 20 v3 = iconst.i64 20
jump block1(v0, v3) ; error: arg 0 (v0) has type i16, expected i64 jump block1(v0, v3) ; error: arg 0 (v0) has type i16, expected i64
; error: arg 1 (v3) has type i64, expected i16 ; error: arg 1 (v3) has type i64, expected i16
block1(v10: i64, v11: i16): block1(v10: i64, v11: i16):
return return
} }

View File

@@ -391,6 +391,8 @@ where
} }
} }
Opcode::CallIndirect => unimplemented!("CallIndirect"), Opcode::CallIndirect => unimplemented!("CallIndirect"),
Opcode::ReturnCall => unimplemented!("ReturnCall"),
Opcode::ReturnCallIndirect => unimplemented!("ReturnCallIndirect"),
Opcode::FuncAddr => unimplemented!("FuncAddr"), Opcode::FuncAddr => unimplemented!("FuncAddr"),
Opcode::Load Opcode::Load
| Opcode::Uload8 | Opcode::Uload8