Files
wasmtime/crates/runtime/src/traphandlers.rs
Alex Crichton 80b095f2e2 Add API to statically assert signature of a Func (#955)
* Add API to statically assert signature of a `Func`

This commit add a family of APIs to `Func` named `getN` where `N` is the
number of arguments. Each function will attempt to statically assert the
signature of a `Func` and, if matching, returns a corresponding closure
which can be used to invoke the underlying function.

The purpose of this commit is to add a highly optimized way to enter a
wasm module, performing type checks up front and avoiding all the costs
of boxing and unboxing arguments within a `Val`. In general this should
be much more optimized than the previous `call` API for entering a wasm
module, if the signature is statically known.

* rustfmt

* Remove stray debugging
2020-02-20 09:28:12 -06:00

370 lines
12 KiB
Rust

//! WebAssembly trap handling, which is built on top of the lower-level
//! signalhandling mechanisms.
use crate::instance::{InstanceHandle, SignalHandler};
use crate::trap_registry::TrapDescription;
use crate::vmcontext::{VMContext, VMFunctionBody};
use backtrace::Backtrace;
use std::any::Any;
use std::cell::Cell;
use std::error::Error;
use std::fmt;
use std::mem;
use std::ptr;
use wasmtime_environ::ir;
extern "C" {
fn RegisterSetjmp(
jmp_buf: *mut *const u8,
callback: extern "C" fn(*mut u8),
payload: *mut u8,
) -> i32;
fn Unwind(jmp_buf: *const u8) -> !;
}
cfg_if::cfg_if! {
if #[cfg(unix)] {
#[no_mangle]
pub unsafe extern "C" fn HandleTrap(
pc: *mut u8,
signum: libc::c_int,
siginfo: *mut libc::siginfo_t,
context: *mut libc::c_void,
) -> *const u8 {
tls::with(|info| {
match info {
Some(info) => info.handle_trap(pc, false, |handler| handler(signum, siginfo, context)),
None => ptr::null(),
}
})
}
} else if #[cfg(target_os = "windows")] {
use winapi::um::winnt::PEXCEPTION_POINTERS;
use winapi::um::minwinbase::EXCEPTION_STACK_OVERFLOW;
#[no_mangle]
pub unsafe extern "C" fn HandleTrap(
pc: *mut u8,
exception_info: PEXCEPTION_POINTERS
) -> *const u8 {
tls::with(|info| {
let reset_guard_page = (*(*exception_info).ExceptionRecord).ExceptionCode == EXCEPTION_STACK_OVERFLOW;
match info {
Some(info) => info.handle_trap(pc, reset_guard_page, |handler| handler(exception_info)),
None => ptr::null(),
}
})
}
}
}
/// Raises a user-defined trap immediately.
///
/// This function performs as-if a wasm trap was just executed, only the trap
/// has a dynamic payload associated with it which is user-provided. This trap
/// payload is then returned from `wasmtime_call` an `wasmtime_call_trampoline`
/// below.
///
/// # Safety
///
/// Only safe to call when wasm code is on the stack, aka `wasmtime_call` or
/// `wasmtime_call_trampoline` must have been previously called.
pub unsafe fn raise_user_trap(data: Box<dyn Error + Send + Sync>) -> ! {
tls::with(|info| info.unwrap().unwind_with(UnwindReason::UserTrap(data)))
}
/// Carries a Rust panic across wasm code and resumes the panic on the other
/// side.
///
/// # Safety
///
/// Only safe to call when wasm code is on the stack, aka `wasmtime_call` or
/// `wasmtime_call_trampoline` must have been previously called.
pub unsafe fn resume_panic(payload: Box<dyn Any + Send>) -> ! {
tls::with(|info| info.unwrap().unwind_with(UnwindReason::Panic(payload)))
}
#[cfg(target_os = "windows")]
fn reset_guard_page() {
extern "C" {
fn _resetstkoflw() -> winapi::ctypes::c_int;
}
// We need to restore guard page under stack to handle future stack overflows properly.
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/resetstkoflw?view=vs-2019
if unsafe { _resetstkoflw() } == 0 {
panic!("failed to restore stack guard page");
}
}
#[cfg(not(target_os = "windows"))]
fn reset_guard_page() {}
/// Stores trace message with backtrace.
#[derive(Debug)]
pub enum Trap {
/// A user-raised trap through `raise_user_trap`.
User(Box<dyn Error + Send + Sync>),
/// A wasm-originating trap from wasm code itself.
Wasm {
/// What sort of trap happened, as well as where in the original wasm module
/// it happened.
desc: TrapDescription,
/// Native stack backtrace at the time the trap occurred
backtrace: Backtrace,
},
}
impl fmt::Display for Trap {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Trap::User(user) => user.fmt(f),
Trap::Wasm { desc, .. } => desc.fmt(f),
}
}
}
impl std::error::Error for Trap {}
/// Call the wasm function pointed to by `callee`.
///
/// * `vmctx` - the callee vmctx argument
/// * `caller_vmctx` - the caller vmctx argument
/// * `trampoline` - the jit-generated trampoline whose ABI takes 4 values, the
/// callee vmctx, the caller vmctx, the `callee` argument below, and then the
/// `values_vec` argument.
/// * `callee` - the third argument to the `trampoline` function
/// * `values_vec` - points to a buffer which holds the incoming arguments, and to
/// which the outgoing return values will be written.
///
/// Wildly unsafe because it calls raw function pointers and reads/writes raw
/// function pointers.
pub unsafe fn wasmtime_call_trampoline(
vmctx: *mut VMContext,
caller_vmctx: *mut VMContext,
trampoline: *const VMFunctionBody,
callee: *const VMFunctionBody,
values_vec: *mut u8,
) -> Result<(), Trap> {
catch_traps(vmctx, || {
mem::transmute::<
_,
extern "C" fn(*mut VMContext, *mut VMContext, *const VMFunctionBody, *mut u8),
>(trampoline)(vmctx, caller_vmctx, callee, values_vec)
})
}
/// Catches any wasm traps that happen within the execution of `closure`,
/// returning them as a `Result`.
///
/// Highly unsafe since `closure` won't have any dtors run.
pub unsafe fn catch_traps<F>(vmctx: *mut VMContext, mut closure: F) -> Result<(), Trap>
where
F: FnMut(),
{
return CallThreadState::new(vmctx).with(|cx| {
RegisterSetjmp(
cx.jmp_buf.as_ptr(),
call_closure::<F>,
&mut closure as *mut F as *mut u8,
)
});
extern "C" fn call_closure<F>(payload: *mut u8)
where
F: FnMut(),
{
unsafe { (*(payload as *mut F))() }
}
}
/// Temporary state stored on the stack which is registered in the `tls` module
/// below for calls into wasm.
pub struct CallThreadState {
unwind: Cell<UnwindReason>,
jmp_buf: Cell<*const u8>,
reset_guard_page: Cell<bool>,
prev: Option<*const CallThreadState>,
vmctx: *mut VMContext,
}
enum UnwindReason {
None,
Panic(Box<dyn Any + Send>),
UserTrap(Box<dyn Error + Send + Sync>),
Trap { backtrace: Backtrace, pc: usize },
}
impl CallThreadState {
fn new(vmctx: *mut VMContext) -> CallThreadState {
CallThreadState {
unwind: Cell::new(UnwindReason::None),
vmctx,
jmp_buf: Cell::new(ptr::null()),
reset_guard_page: Cell::new(false),
prev: None,
}
}
fn with(mut self, closure: impl FnOnce(&CallThreadState) -> i32) -> Result<(), Trap> {
tls::with(|prev| {
self.prev = prev.map(|p| p as *const _);
let ret = tls::set(&self, || closure(&self));
match self.unwind.replace(UnwindReason::None) {
UnwindReason::None => {
debug_assert_eq!(ret, 1);
Ok(())
}
UnwindReason::UserTrap(data) => {
debug_assert_eq!(ret, 0);
Err(Trap::User(data))
}
UnwindReason::Trap { backtrace, pc } => {
debug_assert_eq!(ret, 0);
let instance = unsafe { InstanceHandle::from_vmctx(self.vmctx) };
Err(Trap::Wasm {
desc: instance
.instance()
.trap_registration
.get_trap(pc)
.unwrap_or_else(|| TrapDescription {
source_loc: ir::SourceLoc::default(),
trap_code: ir::TrapCode::StackOverflow,
}),
backtrace,
})
}
UnwindReason::Panic(panic) => {
debug_assert_eq!(ret, 0);
std::panic::resume_unwind(panic)
}
}
})
}
fn any_instance(&self, func: impl Fn(&InstanceHandle) -> bool) -> bool {
unsafe {
if func(&InstanceHandle::from_vmctx(self.vmctx)) {
return true;
}
match self.prev {
Some(prev) => (*prev).any_instance(func),
None => false,
}
}
}
fn unwind_with(&self, reason: UnwindReason) -> ! {
self.unwind.replace(reason);
unsafe {
Unwind(self.jmp_buf.get());
}
}
/// Trap handler using our thread-local state.
///
/// * `pc` - the program counter the trap happened at
/// * `reset_guard_page` - whether or not to reset the guard page,
/// currently Windows specific
/// * `call_handler` - a closure used to invoke the platform-specific
/// signal handler for each instance, if available.
///
/// Attempts to handle the trap if it's a wasm trap. Returns a few
/// different things:
///
/// * null - the trap didn't look like a wasm trap and should continue as a
/// trap
/// * 1 as a pointer - the trap was handled by a custom trap handler on an
/// instance, and the trap handler should quickly return.
/// * a different pointer - a jmp_buf buffer to longjmp to, meaning that
/// the wasm trap was succesfully handled.
fn handle_trap(
&self,
pc: *const u8,
reset_guard_page: bool,
call_handler: impl Fn(&SignalHandler) -> bool,
) -> *const u8 {
// First up see if any instance registered has a custom trap handler,
// in which case run them all. If anything handles the trap then we
// return that the trap was handled.
if self.any_instance(|i| {
let handler = match i.instance().signal_handler.replace(None) {
Some(handler) => handler,
None => return false,
};
let result = call_handler(&handler);
i.instance().signal_handler.set(Some(handler));
return result;
}) {
return 1 as *const _;
}
// TODO: stack overflow can happen at any random time (i.e. in malloc()
// in memory.grow) and it's really hard to determine if the cause was
// stack overflow and if it happened in WebAssembly module.
//
// So, let's assume that any untrusted code called from WebAssembly
// doesn't trap. Then, if we have called some WebAssembly code, it
// means the trap is stack overflow.
if self.jmp_buf.get().is_null() {
return ptr::null();
}
let backtrace = Backtrace::new_unresolved();
self.reset_guard_page.set(reset_guard_page);
self.unwind.replace(UnwindReason::Trap {
backtrace,
pc: pc as usize,
});
self.jmp_buf.get()
}
}
impl Drop for CallThreadState {
fn drop(&mut self) {
if self.reset_guard_page.get() {
reset_guard_page();
}
}
}
// A private inner module for managing the TLS state that we require across
// calls in wasm. The WebAssembly code is called from C++ and then a trap may
// happen which requires us to read some contextual state to figure out what to
// do with the trap. This `tls` module is used to persist that information from
// the caller to the trap site.
mod tls {
use super::CallThreadState;
use std::cell::Cell;
use std::ptr;
thread_local!(static PTR: Cell<*const CallThreadState> = Cell::new(ptr::null()));
/// Configures thread local state such that for the duration of the
/// execution of `closure` any call to `with` will yield `ptr`, unless this
/// is recursively called again.
pub fn set<R>(ptr: &CallThreadState, closure: impl FnOnce() -> R) -> R {
struct Reset<'a, T: Copy>(&'a Cell<T>, T);
impl<T: Copy> Drop for Reset<'_, T> {
fn drop(&mut self) {
self.0.set(self.1);
}
}
PTR.with(|p| {
let _r = Reset(p, p.replace(ptr));
closure()
})
}
/// Returns the last pointer configured with `set` above. Panics if `set`
/// has not been previously called.
pub fn with<R>(closure: impl FnOnce(Option<&CallThreadState>) -> R) -> R {
PTR.with(|ptr| {
let p = ptr.get();
unsafe { closure(if p.is_null() { None } else { Some(&*p) }) }
})
}
}