Remove another thread local in instance.rs (#862)

* Remove another thread local in `instance.rs`

This commit removes another usage of `thread_local!` in the continued
effort to centralize all thread-local state per-call (or basically state
needed for traps) in one location. This removal is targeted at the
support for custom signal handlers on instances, removing the previous
stack of instances with instead a linked list of instances.

The `with_signals_on` method is no longer necessary (since it was always
called anyway) and is inferred from the first `vmctx` argument of the
entrypoints into wasm. These functions establish a linked list of
instances on the stack, if needed, to handle signals when they happen.

This involved some refactoring where some C++ glue was moved into Rust,
so now Rust handles a bit more of the signal handling logic.

* Update some inline docs about `HandleTrap`
This commit is contained in:
Alex Crichton
2020-01-31 13:45:54 +01:00
committed by GitHub
parent cc07565985
commit 97ff297683
7 changed files with 183 additions and 240 deletions

View File

@@ -148,14 +148,12 @@ impl WrappedCallable for WasmtimeFn {
// Call the trampoline.
if let Err(error) = unsafe {
self.instance.with_signals_on(|| {
wasmtime_runtime::wasmtime_call_trampoline(
vmctx,
ptr::null_mut(),
exec_code_buf,
values_vec.as_mut_ptr() as *mut u8,
)
})
wasmtime_runtime::wasmtime_call_trampoline(
vmctx,
ptr::null_mut(),
exec_code_buf,
values_vec.as_mut_ptr() as *mut u8,
)
} {
return Err(Trap::from_jit(error));
}

View File

@@ -18,13 +18,13 @@ pub trait InstanceExt {
/// TODO: needs more documentation.
unsafe fn set_signal_handler<H>(&self, handler: H)
where
H: 'static + Fn(winapi::um::winnt::EXCEPTION_POINTERS) -> bool;
H: 'static + Fn(winapi::um::winnt::PEXCEPTION_POINTERS) -> bool;
}
impl InstanceExt for Instance {
unsafe fn set_signal_handler<H>(&self, handler: H)
where
H: 'static + Fn(winapi::um::winnt::EXCEPTION_POINTERS) -> bool,
H: 'static + Fn(winapi::um::winnt::PEXCEPTION_POINTERS) -> bool,
{
self.instance_handle.clone().set_signal_handler(handler);
}

View File

@@ -190,14 +190,12 @@ pub fn invoke(
// Call the trampoline. Pass a null `caller_vmctx` argument as `invoke` is
// all about calling from the outside world rather than from an instance.
if let Err(trap) = unsafe {
instance.with_signals_on(|| {
wasmtime_call_trampoline(
callee_vmctx,
ptr::null_mut(),
exec_code_buf,
values_vec.as_mut_ptr() as *mut u8,
)
})
wasmtime_call_trampoline(
callee_vmctx,
ptr::null_mut(),
exec_code_buf,
values_vec.as_mut_ptr() as *mut u8,
)
} {
return Ok(ActionOutcome::Trapped(trap));
}

View File

@@ -399,44 +399,10 @@ struct AutoHandlingTrap
}
static
#if defined(__GNUC__) || defined(__clang__)
__attribute__ ((warn_unused_result))
#endif
bool
HandleTrap(CONTEXT* context, bool reset_guard_page)
{
assert(sAlreadyHandlingTrap);
void *JmpBuf = RecordTrap(ContextToPC(context), reset_guard_page);
if (JmpBuf == nullptr) {
return false;
}
// Unwind calls longjmp, so it doesn't run the automatic
// sAlreadhHanldingTrap cleanups, so reset it manually before doing
// a longjmp.
sAlreadyHandlingTrap = false;
#if defined(USE_APPLE_MACH_PORTS)
// Reroute the PC to run the Unwind function on the main stack after the
// handler exits. This doesn't yet work for stack overflow traps, because
// in that case the main thread doesn't have any space left to run.
assert(false); // this branch isn't implemented here
// SetContextPC(context, reinterpret_cast<const uint8_t*>(&Unwind));
#else
// For now, just call Unwind directly, rather than redirecting the PC there,
// so that it runs on the alternate signal handler stack. To run on the main
// stack, reroute the context PC like this:
Unwind(JmpBuf);
#endif
return true;
}
// =============================================================================
// The following platform-specific handlers funnel all signals/exceptions into
// the shared HandleTrap() above.
// the HandleTrap() function defined in Rust. Note that the Rust function has a
// different ABI depending on the platform.
// =============================================================================
#if defined(_WIN32)
@@ -467,16 +433,19 @@ WasmTrapHandler(LPEXCEPTION_POINTERS exception)
return EXCEPTION_CONTINUE_SEARCH;
}
bool handled = InstanceSignalHandler(exception);
void *JmpBuf = HandleTrap(ContextToPC(exception->ContextRecord), exception);
// Test if a custom instance signal handler handled the exception
if (((size_t) JmpBuf) == 1)
return EXCEPTION_CONTINUE_EXECUTION;
if (!handled) {
if (!HandleTrap(exception->ContextRecord,
record->ExceptionCode == EXCEPTION_STACK_OVERFLOW)) {
return EXCEPTION_CONTINUE_SEARCH;
}
// Otherwise test if we need to longjmp to this buffer
if (JmpBuf != nullptr) {
sAlreadyHandlingTrap = false;
Unwind(JmpBuf);
}
return EXCEPTION_CONTINUE_EXECUTION;
// ... and otherwise keep looking for a handler
return EXCEPTION_CONTINUE_SEARCH;
}
#elif defined(USE_APPLE_MACH_PORTS)
@@ -638,16 +607,21 @@ WasmTrapHandler(int signum, siginfo_t* info, void* context)
AutoHandlingTrap aht;
assert(signum == SIGSEGV || signum == SIGBUS || signum == SIGFPE || signum == SIGILL);
if (InstanceSignalHandler(signum, info, (ucontext_t*) context)) {
void *JmpBuf = HandleTrap(ContextToPC(static_cast<CONTEXT*>(context)), signum, info, context);
// Test if a custom instance signal handler handled the exception
if (((size_t) JmpBuf) == 1)
return;
// Otherwise test if we need to longjmp to this buffer
if (JmpBuf != nullptr) {
sAlreadyHandlingTrap = false;
Unwind(JmpBuf);
}
if (HandleTrap(static_cast<CONTEXT*>(context), false)) {
return;
}
// ... and otherwise call the previous signal handler, if one is there
}
struct sigaction* previousSignal = nullptr;
switch (signum) {
case SIGSEGV: previousSignal = &sPrevSIGSEGVHandler; break;

View File

@@ -13,18 +13,12 @@
extern "C" {
#endif
// Record the Trap code and wasm bytecode offset in TLS somewhere
void* RecordTrap(const uint8_t* pc, bool reset_guard_page);
#if defined(_WIN32)
#include <windows.h>
#include <winternl.h>
bool InstanceSignalHandler(LPEXCEPTION_POINTERS);
#elif defined(USE_APPLE_MACH_PORTS)
bool InstanceSignalHandler(int, siginfo_t *, void *);
void* HandleTrap(const uint8_t*, LPEXCEPTION_POINTERS);
#else
#include <sys/ucontext.h>
bool InstanceSignalHandler(int, siginfo_t *, ucontext_t *);
void* HandleTrap(const uint8_t*, int, siginfo_t *, void *);
#endif
void Unwind(void*);

View File

@@ -18,10 +18,9 @@ use crate::vmcontext::{
use memoffset::offset_of;
use more_asserts::assert_lt;
use std::any::Any;
use std::cell::{Cell, RefCell};
use std::cell::Cell;
use std::collections::HashSet;
use std::convert::TryFrom;
use std::ptr::NonNull;
use std::rc::Rc;
use std::sync::Arc;
use std::{mem, ptr, slice};
@@ -33,107 +32,29 @@ use wasmtime_environ::wasm::{
};
use wasmtime_environ::{DataInitializer, Module, TableElements, VMOffsets};
thread_local! {
/// A stack of currently-running `Instance`s, if any.
pub(crate) static CURRENT_INSTANCE: RefCell<Vec<NonNull<Instance>>> = RefCell::new(Vec::new());
}
cfg_if::cfg_if! {
if #[cfg(any(target_os = "linux", target_os = "macos"))] {
pub type SignalHandler = dyn Fn(libc::c_int, *const libc::siginfo_t, *const libc::c_void) -> bool;
pub fn signal_handler_none(
_signum: libc::c_int,
_siginfo: *const libc::siginfo_t,
_context: *const libc::c_void,
) -> bool {
false
}
#[no_mangle]
pub extern "C" fn InstanceSignalHandler(
signum: libc::c_int,
siginfo: *mut libc::siginfo_t,
context: *mut libc::c_void,
) -> bool {
CURRENT_INSTANCE.with(|current_instance| {
let current_instance = current_instance
.try_borrow()
.expect("borrow current instance");
if current_instance.is_empty() {
return false;
} else {
unsafe {
let last = &current_instance
.last()
.expect("current instance not none")
.as_ref();
let f = last
.signal_handler
.replace(Box::new(signal_handler_none));
let ret = f(signum, siginfo, context);
last.signal_handler.set(f);
ret
}
}
})
}
impl InstanceHandle {
/// Set a custom signal handler
pub fn set_signal_handler<H>(&mut self, handler: H)
where
H: 'static + Fn(libc::c_int, *const libc::siginfo_t, *const libc::c_void) -> bool,
{
self.instance().signal_handler.set(Box::new(handler));
self.instance().signal_handler.set(Some(Box::new(handler)));
}
}
} else if #[cfg(target_os = "windows")] {
pub type SignalHandler = dyn Fn(winapi::um::winnt::EXCEPTION_POINTERS) -> bool;
pub fn signal_handler_none(
_exception_info: winapi::um::winnt::EXCEPTION_POINTERS
) -> bool {
false
}
#[no_mangle]
pub extern "C" fn InstanceSignalHandler(
exception_info: winapi::um::winnt::EXCEPTION_POINTERS
) -> bool {
CURRENT_INSTANCE.with(|current_instance| {
let current_instance = current_instance
.try_borrow()
.expect("borrow current instance");
if current_instance.is_empty() {
return false;
} else {
unsafe {
let last = &current_instance
.last()
.expect("current instance not none")
.as_ref();
let f = last
.signal_handler
.replace(Box::new(signal_handler_none));
let ret = f(exception_info);
last.signal_handler.set(f);
ret
}
}
})
}
pub type SignalHandler = dyn Fn(winapi::um::winnt::PEXCEPTION_POINTERS) -> bool;
impl InstanceHandle {
/// Set a custom signal handler
pub fn set_signal_handler<H>(&mut self, handler: H)
where
H: 'static + Fn(winapi::um::winnt::EXCEPTION_POINTERS) -> bool,
H: 'static + Fn(winapi::um::winnt::PEXCEPTION_POINTERS) -> bool,
{
self.instance().signal_handler.set(Box::new(handler));
self.instance().signal_handler.set(Some(Box::new(handler)));
}
}
}
@@ -177,7 +98,7 @@ pub(crate) struct Instance {
dbg_jit_registration: Option<Rc<GdbJitImageRegistration>>,
/// Handler run when `SIGBUS`, `SIGFPE`, `SIGILL`, or `SIGSEGV` are caught by the instance thread.
signal_handler: Cell<Box<SignalHandler>>,
pub(crate) signal_handler: Cell<Option<Box<SignalHandler>>>,
/// Additional context used by compiled wasm code. This field is last, and
/// represents a dynamically-sized array that extends beyond the nominal
@@ -596,28 +517,6 @@ pub struct InstanceHandle {
}
impl InstanceHandle {
#[doc(hidden)]
pub fn with_signals_on<F, R>(&self, action: F) -> R
where
F: FnOnce() -> R,
{
CURRENT_INSTANCE.with(|current_instance| {
current_instance
.borrow_mut()
.push(unsafe { NonNull::new_unchecked(self.instance) });
});
let result = action();
CURRENT_INSTANCE.with(|current_instance| {
let mut current_instance = current_instance.borrow_mut();
assert!(!current_instance.is_empty());
current_instance.pop();
});
result
}
/// Create a new `InstanceHandle` pointing at a new `Instance`.
///
/// # Unsafety
@@ -682,7 +581,7 @@ impl InstanceHandle {
finished_functions,
dbg_jit_registration,
host_state,
signal_handler: Cell::new(Box::new(signal_handler_none)),
signal_handler: Cell::new(None),
vmctx: VMContext {},
};
ptr::write(instance_ptr, instance);
@@ -865,7 +764,7 @@ impl InstanceHandle {
}
/// Return a reference to the contained `Instance`.
fn instance(&self) -> &Instance {
pub(crate) fn instance(&self) -> &Instance {
unsafe { &*(self.instance as *const Instance) }
}
}

View File

@@ -1,6 +1,7 @@
//! WebAssembly trap handling, which is built on top of the lower-level
//! signalhandling mechanisms.
use crate::instance::{InstanceHandle, SignalHandler};
use crate::trap_registry::get_trap_registry;
use crate::trap_registry::TrapDescription;
use crate::vmcontext::{VMContext, VMFunctionBody};
@@ -29,41 +30,40 @@ extern "C" {
fn Unwind(jmp_buf: *const u8) -> !;
}
/// Record the Trap code and wasm bytecode offset in TLS somewhere
#[doc(hidden)]
#[allow(non_snake_case)]
#[no_mangle]
pub extern "C" fn RecordTrap(pc: *const u8, reset_guard_page: bool) -> *const u8 {
tls::with(|info| {
// TODO: stack overflow can happen at any random time (i.e. in malloc()
// in memory.grow) and it's really hard to determine if the cause was
// stack overflow and if it happened in WebAssembly module.
//
// So, let's assume that any untrusted code called from WebAssembly
// doesn't trap. Then, if we have called some WebAssembly code, it
// means the trap is stack overflow.
if info.jmp_buf.get().is_null() {
return ptr::null();
cfg_if::cfg_if! {
if #[cfg(any(target_os = "linux", target_os = "macos"))] {
#[no_mangle]
pub unsafe extern "C" fn HandleTrap(
pc: *mut u8,
signum: libc::c_int,
siginfo: *mut libc::siginfo_t,
context: *mut libc::c_void,
) -> *const u8 {
tls::with(|info| {
match info {
Some(info) => info.handle_trap(pc, false, |handler| handler(signum, siginfo, context)),
None => ptr::null(),
}
})
}
} else if #[cfg(target_os = "windows")] {
use winapi::um::winnt::PEXCEPTION_POINTERS;
use winapi::um::minwinbase::EXCEPTION_STACK_OVERFLOW;
let registry = get_trap_registry();
let trap = Trap::Wasm {
desc: registry
.get_trap(pc as usize)
.unwrap_or_else(|| TrapDescription {
source_loc: ir::SourceLoc::default(),
trap_code: ir::TrapCode::StackOverflow,
}),
backtrace: Backtrace::new_unresolved(),
};
if reset_guard_page {
info.reset_guard_page.set(true);
#[no_mangle]
pub unsafe extern "C" fn HandleTrap(
pc: *mut u8,
exception_info: PEXCEPTION_POINTERS
) -> *const u8 {
tls::with(|info| {
let reset_guard_page = (*(*exception_info).ExceptionRecord).ExceptionCode == EXCEPTION_STACK_OVERFLOW;
match info {
Some(info) => info.handle_trap(pc, reset_guard_page, |handler| handler(exception_info)),
None => ptr::null(),
}
})
}
info.unwind.replace(UnwindReason::Trap(trap));
info.jmp_buf.get()
})
}
}
/// Raises a user-defined trap immediately.
@@ -79,7 +79,7 @@ pub extern "C" fn RecordTrap(pc: *const u8, reset_guard_page: bool) -> *const u8
/// `wasmtime_call_trampoline` must have been previously called.
pub unsafe fn raise_user_trap(data: Box<dyn Error + Send + Sync>) -> ! {
let trap = Trap::User(data);
tls::with(|info| info.unwind_with(UnwindReason::Trap(trap)))
tls::with(|info| info.unwrap().unwind_with(UnwindReason::Trap(trap)))
}
/// Carries a Rust panic across wasm code and resumes the panic on the other
@@ -90,7 +90,7 @@ pub unsafe fn raise_user_trap(data: Box<dyn Error + Send + Sync>) -> ! {
/// Only safe to call when wasm code is on the stack, aka `wasmtime_call` or
/// `wasmtime_call_trampoline` must have been previously called.
pub unsafe fn resume_panic(payload: Box<dyn Any + Send>) -> ! {
tls::with(|info| info.unwind_with(UnwindReason::Panic(payload)))
tls::with(|info| info.unwrap().unwind_with(UnwindReason::Panic(payload)))
}
#[cfg(target_os = "windows")]
@@ -145,8 +145,7 @@ pub unsafe extern "C" fn wasmtime_call_trampoline(
callee: *const VMFunctionBody,
values_vec: *mut u8,
) -> Result<(), Trap> {
let cx = CallThreadState::new();
let ret = tls::set(&cx, || {
CallThreadState::new(vmctx).with(|cx| {
WasmtimeCallTrampoline(
cx.jmp_buf.as_ptr(),
vmctx as *mut u8,
@@ -154,8 +153,7 @@ pub unsafe extern "C" fn wasmtime_call_trampoline(
callee,
values_vec,
)
});
cx.into_result(ret)
})
}
/// Call the wasm function pointed to by `callee`, which has no arguments or
@@ -166,16 +164,14 @@ pub unsafe extern "C" fn wasmtime_call(
caller_vmctx: *mut VMContext,
callee: *const VMFunctionBody,
) -> Result<(), Trap> {
let cx = CallThreadState::new();
let ret = tls::set(&cx, || {
CallThreadState::new(vmctx).with(|cx| {
WasmtimeCall(
cx.jmp_buf.as_ptr(),
vmctx as *mut u8,
caller_vmctx as *mut u8,
callee,
)
});
cx.into_result(ret)
})
}
/// Temporary state stored on the stack which is registered in the `tls` module
@@ -184,6 +180,8 @@ pub struct CallThreadState {
unwind: Cell<UnwindReason>,
jmp_buf: Cell<*const u8>,
reset_guard_page: Cell<bool>,
prev: Option<*const CallThreadState>,
vmctx: *mut VMContext,
}
enum UnwindReason {
@@ -193,27 +191,45 @@ enum UnwindReason {
}
impl CallThreadState {
fn new() -> CallThreadState {
fn new(vmctx: *mut VMContext) -> CallThreadState {
CallThreadState {
unwind: Cell::new(UnwindReason::None),
vmctx,
jmp_buf: Cell::new(ptr::null()),
reset_guard_page: Cell::new(false),
prev: None,
}
}
fn into_result(self, ret: i32) -> Result<(), Trap> {
match self.unwind.replace(UnwindReason::None) {
UnwindReason::None => {
debug_assert_eq!(ret, 1);
Ok(())
fn with(mut self, closure: impl FnOnce(&CallThreadState) -> i32) -> Result<(), Trap> {
tls::with(|prev| {
self.prev = prev.map(|p| p as *const _);
let ret = tls::set(&self, || closure(&self));
match self.unwind.replace(UnwindReason::None) {
UnwindReason::None => {
debug_assert_eq!(ret, 1);
Ok(())
}
UnwindReason::Trap(trap) => {
debug_assert_eq!(ret, 0);
Err(trap)
}
UnwindReason::Panic(panic) => {
debug_assert_eq!(ret, 0);
std::panic::resume_unwind(panic)
}
}
UnwindReason::Trap(trap) => {
debug_assert_eq!(ret, 0);
Err(trap)
})
}
fn any_instance(&self, func: impl Fn(&InstanceHandle) -> bool) -> bool {
unsafe {
if func(&InstanceHandle::from_vmctx(self.vmctx)) {
return true;
}
UnwindReason::Panic(panic) => {
debug_assert_eq!(ret, 0);
std::panic::resume_unwind(panic)
match self.prev {
Some(prev) => (*prev).any_instance(func),
None => false,
}
}
}
@@ -224,6 +240,71 @@ impl CallThreadState {
Unwind(self.jmp_buf.get());
}
}
/// Trap handler using our thread-local state.
///
/// * `pc` - the program counter the trap happened at
/// * `reset_guard_page` - whether or not to reset the guard page,
/// currently Windows specific
/// * `call_handler` - a closure used to invoke the platform-specific
/// signal handler for each instance, if available.
///
/// Attempts to handle the trap if it's a wasm trap. Returns a few
/// different things:
///
/// * null - the trap didn't look like a wasm trap and should continue as a
/// trap
/// * 1 as a pointer - the trap was handled by a custom trap handler on an
/// instance, and the trap handler should quickly return.
/// * a different pointer - a jmp_buf buffer to longjmp to, meaning that
/// the wasm trap was succesfully handled.
fn handle_trap(
&self,
pc: *const u8,
reset_guard_page: bool,
call_handler: impl Fn(&SignalHandler) -> bool,
) -> *const u8 {
// First up see if any instance registered has a custom trap handler,
// in which case run them all. If anything handles the trap then we
// return that the trap was handled.
if self.any_instance(|i| {
let handler = match i.instance().signal_handler.replace(None) {
Some(handler) => handler,
None => return false,
};
let result = call_handler(&handler);
i.instance().signal_handler.set(Some(handler));
return result;
}) {
return 1 as *const _;
}
// TODO: stack overflow can happen at any random time (i.e. in malloc()
// in memory.grow) and it's really hard to determine if the cause was
// stack overflow and if it happened in WebAssembly module.
//
// So, let's assume that any untrusted code called from WebAssembly
// doesn't trap. Then, if we have called some WebAssembly code, it
// means the trap is stack overflow.
if self.jmp_buf.get().is_null() {
return ptr::null();
}
let registry = get_trap_registry();
let trap = Trap::Wasm {
desc: registry
.get_trap(pc as usize)
.unwrap_or_else(|| TrapDescription {
source_loc: ir::SourceLoc::default(),
trap_code: ir::TrapCode::StackOverflow,
}),
backtrace: Backtrace::new_unresolved(),
};
self.reset_guard_page.set(reset_guard_page);
self.unwind.replace(UnwindReason::Trap(trap));
self.jmp_buf.get()
}
}
impl Drop for CallThreadState {
@@ -266,11 +347,10 @@ mod tls {
/// Returns the last pointer configured with `set` above. Panics if `set`
/// has not been previously called.
pub fn with<R>(closure: impl FnOnce(&CallThreadState) -> R) -> R {
pub fn with<R>(closure: impl FnOnce(Option<&CallThreadState>) -> R) -> R {
PTR.with(|ptr| {
let p = ptr.get();
assert!(!p.is_null());
unsafe { closure(&*p) }
unsafe { closure(if p.is_null() { None } else { Some(&*p) }) }
})
}
}