Don't try to handle non-wasmtime segfaults (#1577)

This commit fixes an issue in Wasmtime where Wasmtime would accidentally
"handle" non-wasm segfaults while executing host imports of wasm
modules. If a host import segfaulted then Wasmtime would recognize that
wasm code is on the stack, so it'd longjmp out of the wasm code. This
papers over real bugs though in host code and erroneously classified
segfaults as wasm traps.

The fix here was to add a check to our wasm signal handler for if the
faulting address falls in JIT code itself. Actually threading through
all the right information for that check to happen is a bit tricky,
though, so this involved some refactoring:

* A closure parameter to `catch_traps` was added. This closure is
  responsible for classifying addresses as whether or not they fall in
  JIT code. Anything returning `false` means that the trap won't get
  handled and we'll forward to the next signal handler.

* To avoid passing tons of context all over the place, the start
  function is now no longer automatically invoked by `InstanceHandle`.
  This avoids the need for passing all sorts of trap-handling contextual
  information like the maximum stack size and "is this a jit address"
  closure. Instead creators of `InstanceHandle` (like wasmtime) are now
  responsible for invoking the start function.

* To avoid excessive use of `transmute` with lifetimes since the
  traphandler state now has a lifetime the per-instance custom signal
  handler is now replaced with a per-store custom signal handler. I'm
  not entirely certain the purpose of the custom signal handler, though,
  so I'd look for feedback on this part.

A new test has been added which ensures that if a host function
segfaults we don't accidentally try to handle it, and instead we
correctly report the segfault.
This commit is contained in:
Alex Crichton
2020-04-29 14:24:54 -05:00
committed by GitHub
parent 8ee8c322ae
commit d719ec7e1c
12 changed files with 214 additions and 222 deletions

View File

@@ -1,7 +1,6 @@
//! WebAssembly trap handling, which is built on top of the lower-level
//! signalhandling mechanisms.
use crate::instance::{InstanceHandle, SignalHandler};
use crate::VMContext;
use backtrace::Backtrace;
use std::any::Any;
@@ -26,6 +25,9 @@ cfg_if::cfg_if! {
if #[cfg(unix)] {
use std::mem::{self, MaybeUninit};
/// Function which may handle custom signals while processing traps.
pub type SignalHandler<'a> = dyn Fn(libc::c_int, *const libc::siginfo_t, *const libc::c_void) -> bool + 'a;
static mut PREV_SIGSEGV: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();
static mut PREV_SIGBUS: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();
static mut PREV_SIGILL: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();
@@ -180,6 +182,9 @@ cfg_if::cfg_if! {
use winapi::um::minwinbase::*;
use winapi::vc::excpt::*;
/// Function which may handle custom signals while processing traps.
pub type SignalHandler<'a> = dyn Fn(winapi::um::winnt::PEXCEPTION_POINTERS) -> bool + 'a;
unsafe fn platform_init() {
// our trap handler needs to go first, so that we can recover from
// wasm faults and continue execution, so pass `1` as a true value
@@ -361,6 +366,8 @@ impl Trap {
pub unsafe fn catch_traps<F>(
vmctx: *mut VMContext,
max_wasm_stack: usize,
is_wasm_code: impl Fn(usize) -> bool,
signal_handler: Option<&SignalHandler>,
mut closure: F,
) -> Result<(), Trap>
where
@@ -370,7 +377,7 @@ where
#[cfg(unix)]
setup_unix_sigaltstack()?;
return CallThreadState::new(vmctx).with(max_wasm_stack, |cx| {
return CallThreadState::new(vmctx, &is_wasm_code, signal_handler).with(max_wasm_stack, |cx| {
RegisterSetjmp(
cx.jmp_buf.as_ptr(),
call_closure::<F>,
@@ -388,12 +395,13 @@ where
/// Temporary state stored on the stack which is registered in the `tls` module
/// below for calls into wasm.
pub struct CallThreadState {
pub struct CallThreadState<'a> {
unwind: Cell<UnwindReason>,
jmp_buf: Cell<*const u8>,
prev: Option<*const CallThreadState>,
vmctx: *mut VMContext,
handling_trap: Cell<bool>,
is_wasm_code: &'a (dyn Fn(usize) -> bool + 'a),
signal_handler: Option<&'a SignalHandler<'a>>,
}
enum UnwindReason {
@@ -404,54 +412,56 @@ enum UnwindReason {
JitTrap { backtrace: Backtrace, pc: usize },
}
impl CallThreadState {
fn new(vmctx: *mut VMContext) -> CallThreadState {
impl<'a> CallThreadState<'a> {
fn new(
vmctx: *mut VMContext,
is_wasm_code: &'a (dyn Fn(usize) -> bool + 'a),
signal_handler: Option<&'a SignalHandler<'a>>,
) -> CallThreadState<'a> {
CallThreadState {
unwind: Cell::new(UnwindReason::None),
vmctx,
jmp_buf: Cell::new(ptr::null()),
prev: None,
handling_trap: Cell::new(false),
is_wasm_code,
signal_handler,
}
}
fn with(
mut self,
self,
max_wasm_stack: usize,
closure: impl FnOnce(&CallThreadState) -> i32,
) -> Result<(), Trap> {
tls::with(|prev| {
self.prev = prev.map(|p| p as *const _);
let _reset = self.update_stack_limit(max_wasm_stack)?;
let ret = tls::set(&self, || closure(&self));
match self.unwind.replace(UnwindReason::None) {
UnwindReason::None => {
debug_assert_eq!(ret, 1);
Ok(())
}
UnwindReason::UserTrap(data) => {
debug_assert_eq!(ret, 0);
Err(Trap::User(data))
}
UnwindReason::LibTrap(trap) => Err(trap),
UnwindReason::JitTrap { backtrace, pc } => {
debug_assert_eq!(ret, 0);
let maybe_interrupted = unsafe {
(*self.vmctx).instance().interrupts.stack_limit.load(SeqCst)
== wasmtime_environ::INTERRUPTED
};
Err(Trap::Jit {
pc,
backtrace,
maybe_interrupted,
})
}
UnwindReason::Panic(panic) => {
debug_assert_eq!(ret, 0);
std::panic::resume_unwind(panic)
}
let _reset = self.update_stack_limit(max_wasm_stack)?;
let ret = tls::set(&self, || closure(&self));
match self.unwind.replace(UnwindReason::None) {
UnwindReason::None => {
debug_assert_eq!(ret, 1);
Ok(())
}
})
UnwindReason::UserTrap(data) => {
debug_assert_eq!(ret, 0);
Err(Trap::User(data))
}
UnwindReason::LibTrap(trap) => Err(trap),
UnwindReason::JitTrap { backtrace, pc } => {
debug_assert_eq!(ret, 0);
let maybe_interrupted = unsafe {
(*self.vmctx).instance().interrupts.stack_limit.load(SeqCst)
== wasmtime_environ::INTERRUPTED
};
Err(Trap::Jit {
pc,
backtrace,
maybe_interrupted,
})
}
UnwindReason::Panic(panic) => {
debug_assert_eq!(ret, 0);
std::panic::resume_unwind(panic)
}
}
}
/// Checks and/or initializes the wasm native call stack limit.
@@ -535,18 +545,6 @@ impl CallThreadState {
Ok(Reset(reset_stack_limit, &interrupts.stack_limit))
}
fn any_instance(&self, func: impl Fn(&InstanceHandle) -> bool) -> bool {
unsafe {
if func(&InstanceHandle::from_vmctx(self.vmctx)) {
return true;
}
match self.prev {
Some(prev) => (*prev).any_instance(func),
None => false,
}
}
}
fn unwind_with(&self, reason: UnwindReason) -> ! {
self.unwind.replace(reason);
unsafe {
@@ -582,21 +580,25 @@ impl CallThreadState {
if self.handling_trap.replace(true) {
return ptr::null();
}
let _reset = ResetCell(&self.handling_trap, false);
// If we haven't even started to handle traps yet, bail out.
if self.jmp_buf.get().is_null() {
return ptr::null();
}
// If this fault wasn't in wasm code, then it's not our problem
if !(self.is_wasm_code)(pc as usize) {
return ptr::null();
}
// First up see if any instance registered has a custom trap handler,
// in which case run them all. If anything handles the trap then we
// return that the trap was handled.
if self.any_instance(|i| {
let handler = match i.instance().signal_handler.replace(None) {
Some(handler) => handler,
None => return false,
};
let result = call_handler(&handler);
i.instance().signal_handler.set(Some(handler));
return result;
}) {
self.handling_trap.set(false);
return 1 as *const _;
if let Some(handler) = self.signal_handler {
if call_handler(handler) {
return 1 as *const _;
}
}
// TODO: stack overflow can happen at any random time (i.e. in malloc()
@@ -607,7 +609,6 @@ impl CallThreadState {
// doesn't trap. Then, if we have called some WebAssembly code, it
// means the trap is stack overflow.
if self.jmp_buf.get().is_null() {
self.handling_trap.set(false);
return ptr::null();
}
let backtrace = Backtrace::new_unresolved();
@@ -615,11 +616,18 @@ impl CallThreadState {
backtrace,
pc: pc as usize,
});
self.handling_trap.set(false);
self.jmp_buf.get()
}
}
struct ResetCell<'a, T: Copy>(&'a Cell<T>, T);
impl<T: Copy> Drop for ResetCell<'_, T> {
fn drop(&mut self) {
self.0.set(self.1);
}
}
// A private inner module for managing the TLS state that we require across
// calls in wasm. The WebAssembly code is called from C++ and then a trap may
// happen which requires us to read some contextual state to figure out what to
@@ -628,14 +636,15 @@ impl CallThreadState {
mod tls {
use super::CallThreadState;
use std::cell::Cell;
use std::mem;
use std::ptr;
thread_local!(static PTR: Cell<*const CallThreadState> = Cell::new(ptr::null()));
thread_local!(static PTR: Cell<*const CallThreadState<'static>> = Cell::new(ptr::null()));
/// Configures thread local state such that for the duration of the
/// execution of `closure` any call to `with` will yield `ptr`, unless this
/// is recursively called again.
pub fn set<R>(ptr: &CallThreadState, closure: impl FnOnce() -> R) -> R {
pub fn set<R>(ptr: &CallThreadState<'_>, closure: impl FnOnce() -> R) -> R {
struct Reset<'a, T: Copy>(&'a Cell<T>, T);
impl<T: Copy> Drop for Reset<'_, T> {
@@ -645,6 +654,12 @@ mod tls {
}
PTR.with(|p| {
// Note that this extension of the lifetime to `'static` should be
// safe because we only ever access it below with an anonymous
// lifetime, meaning `'static` never leaks out of this module.
let ptr = unsafe {
mem::transmute::<*const CallThreadState<'_>, *const CallThreadState<'static>>(ptr)
};
let _r = Reset(p, p.replace(ptr));
closure()
})
@@ -652,7 +667,7 @@ mod tls {
/// Returns the last pointer configured with `set` above. Panics if `set`
/// has not been previously called.
pub fn with<R>(closure: impl FnOnce(Option<&CallThreadState>) -> R) -> R {
pub fn with<R>(closure: impl FnOnce(Option<&CallThreadState<'_>>) -> R) -> R {
PTR.with(|ptr| {
let p = ptr.get();
unsafe { closure(if p.is_null() { None } else { Some(&*p) }) }