From e9095050bef2d9f0004c667af07345c582efac6f Mon Sep 17 00:00:00 2001 From: Afonso Bordado Date: Sat, 25 Feb 2023 13:16:59 +0000 Subject: [PATCH] cranelift-interpreter: Implement `call_indirect` and `return_call_indirect` (#5877) * cranelift-interpreter: Implement `call_indirect` * cranelift: Fix typo * riscv64: Enable `call_indirect` tests --- .../filetests/runtests/call_indirect.clif | 2 + .../runtests/return-call-indirect.clif | 76 ++++++++++ cranelift/interpreter/src/address.rs | 54 ++++--- cranelift/interpreter/src/interpreter.rs | 65 +++++++- cranelift/interpreter/src/state.rs | 54 ++++++- cranelift/interpreter/src/step.rs | 139 ++++++++++++------ 6 files changed, 321 insertions(+), 69 deletions(-) create mode 100644 cranelift/filetests/filetests/runtests/return-call-indirect.clif diff --git a/cranelift/filetests/filetests/runtests/call_indirect.clif b/cranelift/filetests/filetests/runtests/call_indirect.clif index 3705001c98..6a71e492e5 100644 --- a/cranelift/filetests/filetests/runtests/call_indirect.clif +++ b/cranelift/filetests/filetests/runtests/call_indirect.clif @@ -1,9 +1,11 @@ +test interpret test run target x86_64 target aarch64 target aarch64 sign_return_address target aarch64 has_pauth sign_return_address target s390x +target riscv64gc function %callee_indirect(i64) -> i64 { diff --git a/cranelift/filetests/filetests/runtests/return-call-indirect.clif b/cranelift/filetests/filetests/runtests/return-call-indirect.clif new file mode 100644 index 0000000000..4277aaa288 --- /dev/null +++ b/cranelift/filetests/filetests/runtests/return-call-indirect.clif @@ -0,0 +1,76 @@ +test interpret +;; test run +;; target x86_64 +;; target aarch64 +;; target aarch64 sign_return_address +;; target aarch64 has_pauth sign_return_address +;; target s390x + +;;;; Test passing `i64`s ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + +function %callee_i64(i64) -> i64 tail { +block0(v0: i64): + v1 = iadd_imm.i64 v0, 10 + return v1 +} + +function %call_i64(i64) 
-> i64 tail { + fn0 = %callee_i64(i64) -> i64 tail + ; sig0 = (i64) -> i64 tail + +block0(v0: i64): + v1 = func_addr.i64 fn0 + return_call_indirect sig0, v1(v0) +} +; run: %call_i64(10) == 20 + +;;;; Test colocated tail calls ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + +function %colocated_i64(i64) -> i64 tail { + fn0 = colocated %callee_i64(i64) -> i64 tail + ; sig0 = (i64) -> i64 tail + +block0(v0: i64): + v1 = func_addr.i64 fn0 + return_call_indirect sig0, v1(v0) +} +; run: %colocated_i64(10) == 20 + +;;;; Test passing `f64`s ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + +function %callee_f64(f64) -> f64 tail { +block0(v0: f64): + v1 = f64const 0x10.0 + v2 = fadd.f64 v0, v1 + return v2 +} + +function %call_f64(f64) -> f64 tail { + fn0 = %callee_f64(f64) -> f64 tail + ; sig0 = (f64) -> f64 tail + +block0(v0: f64): + v1 = func_addr.i64 fn0 + return_call_indirect sig0, v1(v0) +} +; run: %call_f64(0x10.0) == 0x20.0 + +;;;; Test passing `i8`s ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + +function %callee_i8(i8) -> i8 tail { +block0(v0: i8): + v1 = iconst.i8 0 + v2 = icmp eq v0, v1 + return v2 +} + +function %call_i8(i8) -> i8 tail { + fn0 = %callee_i8(i8) -> i8 tail + ; sig0 = (i8) -> i8 tail + +block0(v0: i8): + v1 = func_addr.i64 fn0 + return_call_indirect sig0, v1(v0) +} +; run: %call_i8(1) == 0 +; run: %call_i8(0) == 1 diff --git a/cranelift/interpreter/src/address.rs b/cranelift/interpreter/src/address.rs index 1d1f831141..97d610d27f 100644 --- a/cranelift/interpreter/src/address.rs +++ b/cranelift/interpreter/src/address.rs @@ -15,8 +15,8 @@ //! are the "entry" field, the amount of "entry" bits depends on the size of the address and //! the "region" of the address. The remaining bits belong to the "offset" field //! -//! An example address could be a 32 bit address, in the `heap` region, which has 2 "entry" bits -//! this address would have 32 - 2 - 2 = 28 offset bits. +//! 
An example address could be a 32 bit address, in the `function` region, which has 1 "entry" bit +//! this address would have 32 - 1 - 2 = 29 offset bits. //! //! The only exception to this is the "stack" region, where, because we only have a single "stack" //! we have 0 "entry" bits, and thus is all offset. @@ -24,11 +24,11 @@ //! | address size | address kind | region value (2 bits) | entry bits (#) | offset bits (#) | //! |--------------|--------------|-----------------------|----------------|-----------------| //! | 32 | Stack | 0b00 | 0 | 30 | -//! | 32 | Heap | 0b01 | 2 | 28 | +//! | 32 | Function | 0b01 | 1 | 29 | //! | 32 | Table | 0b10 | 5 | 25 | //! | 32 | GlobalValue | 0b11 | 6 | 24 | //! | 64 | Stack | 0b00 | 0 | 62 | -//! | 64 | Heap | 0b01 | 6 | 56 | +//! | 64 | Function | 0b01 | 1 | 61 | //! | 64 | Table | 0b10 | 10 | 52 | //! | 64 | GlobalValue | 0b11 | 12 | 50 | @@ -68,7 +68,7 @@ impl TryFrom for AddressSize { #[derive(Debug, Copy, Clone, PartialEq)] pub enum AddressRegion { Stack, - Heap, + Function, Table, GlobalValue, } @@ -78,7 +78,7 @@ impl AddressRegion { assert!(bits < 4); match bits { 0 => AddressRegion::Stack, - 1 => AddressRegion::Heap, + 1 => AddressRegion::Function, 2 => AddressRegion::Table, 3 => AddressRegion::GlobalValue, _ => unreachable!(), @@ -88,7 +88,7 @@ impl AddressRegion { pub fn encode(self) -> u64 { match self { AddressRegion::Stack => 0, - AddressRegion::Heap => 1, + AddressRegion::Function => 1, AddressRegion::Table => 2, AddressRegion::GlobalValue => 3, } @@ -143,11 +143,13 @@ impl Address { // We only have one stack, so the whole address is offset (_, AddressRegion::Stack) => 0, - (AddressSize::_32, AddressRegion::Heap) => 2, + // We have two function "entries", one for libcalls, and + // another for user functions. 
+ (_, AddressRegion::Function) => 1, + (AddressSize::_32, AddressRegion::Table) => 5, (AddressSize::_32, AddressRegion::GlobalValue) => 6, - (AddressSize::_64, AddressRegion::Heap) => 6, (AddressSize::_64, AddressRegion::Table) => 10, (AddressSize::_64, AddressRegion::GlobalValue) => 12, } @@ -224,6 +226,22 @@ impl TryFrom for Address { } } +#[derive(Debug, Clone, PartialEq)] +pub enum AddressFunctionEntry { + UserFunction = 0, + LibCall, +} + +impl From for AddressFunctionEntry { + fn from(bits: u64) -> Self { + match bits { + 0 => AddressFunctionEntry::UserFunction, + 1 => AddressFunctionEntry::LibCall, + _ => unreachable!(), + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -233,7 +251,7 @@ mod tests { fn address_region_roundtrip_encode_decode() { let all_regions = [ AddressRegion::Stack, - AddressRegion::Heap, + AddressRegion::Function, AddressRegion::Table, AddressRegion::GlobalValue, ]; @@ -250,10 +268,10 @@ mod tests { (AddressSize::_32, AddressRegion::Stack, 0, 1), (AddressSize::_32, AddressRegion::Stack, 0, 1024), (AddressSize::_32, AddressRegion::Stack, 0, 0x3FFF_FFFF), - (AddressSize::_32, AddressRegion::Heap, 0, 0), - (AddressSize::_32, AddressRegion::Heap, 1, 1), - (AddressSize::_32, AddressRegion::Heap, 3, 1024), - (AddressSize::_32, AddressRegion::Heap, 3, 0x0FFF_FFFF), + (AddressSize::_32, AddressRegion::Function, 0, 0), + (AddressSize::_32, AddressRegion::Function, 1, 1), + (AddressSize::_32, AddressRegion::Function, 0, 1024), + (AddressSize::_32, AddressRegion::Function, 1, 0x0FFF_FFFF), (AddressSize::_32, AddressRegion::Table, 0, 0), (AddressSize::_32, AddressRegion::Table, 1, 1), (AddressSize::_32, AddressRegion::Table, 31, 0x1FF_FFFF), @@ -268,10 +286,10 @@ mod tests { 0, 0x3FFFFFFF_FFFFFFFF, ), - (AddressSize::_64, AddressRegion::Heap, 0, 0), - (AddressSize::_64, AddressRegion::Heap, 1, 1), - (AddressSize::_64, AddressRegion::Heap, 3, 1024), - (AddressSize::_64, AddressRegion::Heap, 3, 0x0FFF_FFFF), + (AddressSize::_64, 
AddressRegion::Function, 0, 0), + (AddressSize::_64, AddressRegion::Function, 1, 1), + (AddressSize::_64, AddressRegion::Function, 0, 1024), + (AddressSize::_64, AddressRegion::Function, 1, 0x0FFF_FFFF), (AddressSize::_64, AddressRegion::Table, 0, 0), (AddressSize::_64, AddressRegion::Table, 1, 1), (AddressSize::_64, AddressRegion::Table, 31, 0x1FF_FFFF), diff --git a/cranelift/interpreter/src/interpreter.rs b/cranelift/interpreter/src/interpreter.rs index 7a5c00a163..290901dc44 100644 --- a/cranelift/interpreter/src/interpreter.rs +++ b/cranelift/interpreter/src/interpreter.rs @@ -2,17 +2,17 @@ //! //! This module partially contains the logic for interpreting Cranelift IR. -use crate::address::{Address, AddressRegion, AddressSize}; +use crate::address::{Address, AddressFunctionEntry, AddressRegion, AddressSize}; use crate::environment::{FuncIndex, FunctionStore}; use crate::frame::Frame; use crate::instruction::DfgInstructionContext; -use crate::state::{MemoryError, State}; +use crate::state::{InterpreterFunctionRef, MemoryError, State}; use crate::step::{step, ControlFlow, StepError}; use crate::value::{Value, ValueError}; use cranelift_codegen::data_value::DataValue; use cranelift_codegen::ir::{ - ArgumentPurpose, Block, FuncRef, Function, GlobalValue, GlobalValueData, LibCall, StackSlot, - TrapCode, Type, Value as ValueRef, + ArgumentPurpose, Block, ExternalName, FuncRef, Function, GlobalValue, GlobalValueData, LibCall, + StackSlot, TrapCode, Type, Value as ValueRef, }; use log::trace; use smallvec::SmallVec; @@ -346,6 +346,63 @@ impl<'a> State<'a, DataValue> for InterpreterState<'a> { Ok(v.write_to_slice(dst)) } + fn function_address( + &self, + size: AddressSize, + name: &ExternalName, + ) -> Result { + let curr_func = self.get_current_function(); + let (entry, index) = match name { + ExternalName::User(username) => { + let ext_name = &curr_func.params.user_named_funcs()[*username]; + + // TODO: This is not optimal since we are looking up by string name + let 
index = self.functions.index_of(&ext_name.to_string()).unwrap(); + + (AddressFunctionEntry::UserFunction, index.as_u32()) + } + + ExternalName::TestCase(testname) => { + // TODO: This is not optimal since we are looking up by string name + let index = self.functions.index_of(&testname.to_string()).unwrap(); + + (AddressFunctionEntry::UserFunction, index.as_u32()) + } + ExternalName::LibCall(libcall) => { + // We don't properly have a "libcall" store, but we can use `LibCall::all_libcalls()` + // and index into that. + let index = LibCall::all_libcalls() + .iter() + .position(|lc| lc == libcall) + .unwrap(); + + (AddressFunctionEntry::LibCall, index as u32) + } + _ => unimplemented!("function_address: {:?}", name), + }; + + Address::from_parts(size, AddressRegion::Function, entry as u64, index as u64) + } + + fn get_function_from_address(&self, address: Address) -> Option> { + let index = address.offset as u32; + if address.region != AddressRegion::Function { + return None; + } + + match AddressFunctionEntry::from(address.entry) { + AddressFunctionEntry::UserFunction => self + .functions + .get_by_index(FuncIndex::from_u32(index)) + .map(InterpreterFunctionRef::from), + + AddressFunctionEntry::LibCall => LibCall::all_libcalls() + .get(index as usize) + .copied() + .map(InterpreterFunctionRef::from), + } + } + /// Non-Recursively resolves a global value until its address is found fn resolve_global_value(&self, gv: GlobalValue) -> Result { // Resolving a Global Value is a "pointer" chasing operation that lends itself to diff --git a/cranelift/interpreter/src/state.rs b/cranelift/interpreter/src/state.rs index 49526e4a83..b98f0c5525 100644 --- a/cranelift/interpreter/src/state.rs +++ b/cranelift/interpreter/src/state.rs @@ -3,7 +3,10 @@ use crate::address::{Address, AddressSize}; use crate::interpreter::LibCallHandler; use cranelift_codegen::data_value::DataValue; -use cranelift_codegen::ir::{FuncRef, Function, GlobalValue, StackSlot, Type, Value}; +use cranelift_codegen::ir::{
+ ExternalName, FuncRef, Function, GlobalValue, LibCall, Signature, StackSlot, Type, Value, +}; +use cranelift_codegen::isa::CallConv; use cranelift_entity::PrimaryMap; use smallvec::SmallVec; use thiserror::Error; @@ -64,6 +67,16 @@ pub trait State<'a, V> { /// stack or to one of the heaps; the number of bytes stored corresponds to the specified [Type]. fn checked_store(&mut self, address: Address, v: V) -> Result<(), MemoryError>; + /// Compute the address of a function given its name. + fn function_address( + &self, + size: AddressSize, + name: &ExternalName, + ) -> Result; + + /// Retrieve a reference to a [Function] given its address. + fn get_function_from_address(&self, address: Address) -> Option>; + /// Given a global value, compute the final value for that global value, applying all operations /// in intermediate global values. fn resolve_global_value(&self, gv: GlobalValue) -> Result; @@ -77,6 +90,33 @@ pub trait State<'a, V> { fn set_pinned_reg(&mut self, v: V); } +pub enum InterpreterFunctionRef<'a> { + Function(&'a Function), + LibCall(LibCall), +} + +impl<'a> InterpreterFunctionRef<'a> { + pub fn signature(&self) -> Signature { + match self { + InterpreterFunctionRef::Function(f) => f.stencil.signature.clone(), + // CallConv here is sort of irrelevant, since we don't use it for anything + InterpreterFunctionRef::LibCall(lc) => lc.signature(CallConv::SystemV), + } + } +} + +impl<'a> From<&'a Function> for InterpreterFunctionRef<'a> { + fn from(f: &'a Function) -> Self { + InterpreterFunctionRef::Function(f) + } +} + +impl From for InterpreterFunctionRef<'_> { + fn from(lc: LibCall) -> Self { + InterpreterFunctionRef::LibCall(lc) + } +} + #[derive(Error, Debug)] pub enum MemoryError { #[error("Invalid DataValue passed as an address: {0}")] @@ -150,6 +190,18 @@ where unimplemented!() } + fn function_address( + &self, + _size: AddressSize, + _name: &ExternalName, + ) -> Result { + unimplemented!() + } + + fn get_function_from_address(&self, _address: 
Address) -> Option> { + unimplemented!() + } + fn resolve_global_value(&self, _gv: GlobalValue) -> Result { unimplemented!() } diff --git a/cranelift/interpreter/src/step.rs b/cranelift/interpreter/src/step.rs index c77a5fa43c..4f230072a4 100644 --- a/cranelift/interpreter/src/step.rs +++ b/cranelift/interpreter/src/step.rs @@ -2,7 +2,7 @@ //! [InstructionContext]; the interpretation is generic over [Value]s. use crate::address::{Address, AddressSize}; use crate::instruction::InstructionContext; -use crate::state::{MemoryError, State}; +use crate::state::{InterpreterFunctionRef, MemoryError, State}; use crate::value::{Value, ValueConversionKind, ValueError, ValueResult}; use cranelift_codegen::data_value::DataValue; use cranelift_codegen::ir::condcodes::{FloatCC, IntCC}; @@ -276,35 +276,12 @@ where } }; - // Perform a call operation. - // - // The returned `ControlFlow` variant is determined by the given function - // argument, which should make either a `ControlFlow::Call` or a - // `ControlFlow::ReturnCall`. - let do_call = |make_ctrl_flow: fn(&'a Function, SmallVec<[V; 1]>) -> ControlFlow<'a, V>| + // Calls a function reference with the given arguments. + let call_func = |func_ref: InterpreterFunctionRef<'a>, + args: SmallVec<[V; 1]>, + make_ctrl_flow: fn(&'a Function, SmallVec<[V; 1]>) -> ControlFlow<'a, V>| -> Result, StepError> { - let func_ref = if let InstructionData::Call { func_ref, .. } = inst { - func_ref - } else { - unreachable!() - }; - - let curr_func = state.get_current_function(); - let ext_data = curr_func - .dfg - .ext_funcs - .get(func_ref) - .ok_or(StepError::UnknownFunction(func_ref))?; - - let signature = if let Some(sig) = curr_func.dfg.signatures.get(ext_data.signature) { - sig - } else { - return Ok(ControlFlow::Trap(CraneliftTrap::User( - TrapCode::BadSignature, - ))); - }; - - let args = args()?; + let signature = func_ref.signature(); // Check the types of the arguments. 
This is usually done by the verifier, but nothing // guarantees that the user has ran that. @@ -315,17 +292,16 @@ where ))); } - Ok(match ext_data.name { - // These functions should be registered in the regular function store - ExternalName::User(_) | ExternalName::TestCase(_) => { - let function = state - .get_function(func_ref) - .ok_or(StepError::UnknownFunction(func_ref))?; - - make_ctrl_flow(function, args) - } - ExternalName::LibCall(libcall) => { - debug_assert_ne!(inst.opcode(), Opcode::ReturnCall, "Cannot tail call to libcalls"); + Ok(match func_ref { + InterpreterFunctionRef::Function(func) => make_ctrl_flow(func, args), + InterpreterFunctionRef::LibCall(libcall) => { + debug_assert!( + !matches!( + inst.opcode(), + Opcode::ReturnCall | Opcode::ReturnCallIndirect, + ), + "Cannot tail call to libcalls" + ); let libcall_handler = state.get_libcall_handler(); // We don't transfer control to a libcall, we just execute it and return the results @@ -342,7 +318,6 @@ where ControlFlow::Trap(CraneliftTrap::User(TrapCode::BadSignature)) } } - ExternalName::KnownSymbol(_) => unimplemented!(), }) }; @@ -398,11 +373,83 @@ where Opcode::Trapnz => trap_when(arg(0)?.into_bool()?, CraneliftTrap::User(trap_code())), Opcode::ResumableTrapnz => trap_when(arg(0)?.into_bool()?, CraneliftTrap::Resumable), Opcode::Return => ControlFlow::Return(args()?), - Opcode::Call => do_call(ControlFlow::Call)?, - Opcode::CallIndirect => unimplemented!("CallIndirect"), - Opcode::ReturnCall => do_call(ControlFlow::ReturnCall)?, - Opcode::ReturnCallIndirect => unimplemented!("ReturnCallIndirect"), - Opcode::FuncAddr => unimplemented!("FuncAddr"), + Opcode::Call | Opcode::ReturnCall => { + let func_ref = if let InstructionData::Call { func_ref, .. 
} = inst { + func_ref + } else { + unreachable!() + }; + + let curr_func = state.get_current_function(); + let ext_data = curr_func + .dfg + .ext_funcs + .get(func_ref) + .ok_or(StepError::UnknownFunction(func_ref))?; + + let args = args()?; + let func = match ext_data.name { + // These functions should be registered in the regular function store + ExternalName::User(_) | ExternalName::TestCase(_) => { + let function = state + .get_function(func_ref) + .ok_or(StepError::UnknownFunction(func_ref))?; + InterpreterFunctionRef::Function(function) + } + ExternalName::LibCall(libcall) => InterpreterFunctionRef::LibCall(libcall), + ExternalName::KnownSymbol(_) => unimplemented!(), + }; + + let make_control_flow = match inst.opcode() { + Opcode::Call => ControlFlow::Call, + Opcode::ReturnCall => ControlFlow::ReturnCall, + _ => unreachable!(), + }; + + call_func(func, args, make_control_flow)? + } + Opcode::CallIndirect | Opcode::ReturnCallIndirect => { + let args = args()?; + let addr_dv = DataValue::U64(arg(0)?.into_int()? as u64); + let addr = Address::try_from(addr_dv.clone()).map_err(StepError::MemoryError)?; + + let func = state + .get_function_from_address(addr) + .ok_or_else(|| StepError::MemoryError(MemoryError::InvalidAddress(addr_dv)))?; + + let call_args: SmallVec<[V; 1]> = SmallVec::from(&args[1..]); + + let make_control_flow = match inst.opcode() { + Opcode::CallIndirect => ControlFlow::Call, + Opcode::ReturnCallIndirect => ControlFlow::ReturnCall, + _ => unreachable!(), + }; + + call_func(func, call_args, make_control_flow)? + } + Opcode::FuncAddr => { + let func_ref = if let InstructionData::FuncAddr { func_ref, .. 
} = inst { + func_ref + } else { + unreachable!() + }; + + let ext_data = state + .get_current_function() + .dfg + .ext_funcs + .get(func_ref) + .ok_or(StepError::UnknownFunction(func_ref))?; + + let addr_ty = inst_context.controlling_type().unwrap(); + assign_or_memtrap({ + AddressSize::try_from(addr_ty).and_then(|addr_size| { + let addr = state.function_address(addr_size, &ext_data.name)?; + let dv = DataValue::try_from(addr)?; + Ok(dv.into()) + }) + }) + } Opcode::Load | Opcode::Uload8 | Opcode::Sload8