Reduce number of thread locals in trap handling (#852)
* Reduce number of thread locals in trap handling This commit refactors the trap handling portion of wasmtime with a few goals in mind. I've been reading around a bit lately and feel that we have a bit too many globals and thread locals floating around rather than handles attached to contexts. I'm hoping that we can reduce the number of thread locals and globals, and this commit is the start of reducing this number. The changes applied in this commit remove the set of thread locals in the `traphandlers` module in favor of one thread local that's managed in a sort of stack discipline. This way each call to `wasmtime_call*` sets up its own stack-local state that can be managed and read on that stack frame. Additionally the C++ glue code around `setjmp` and `longjmp` has all been refactored to avoid going back and forth between Rust and C++. Now we'll simply enter C++, go straight into `setjmp`/the call, and then traps will enter Rust only once to both learn if the trap should be acted upon and record information about the trap. Overall the hope here is that context passing between `wasmtime_call*` and the trap handling function will be a bit easier. For example I hope to remove the global `get_trap_registry()` function next in favor of storing a handle to a registry inside each instance, and the `*mut VMContext` can be used to reach the `InstanceHandle` underneath, and this trap registry. * Update crates/runtime/src/traphandlers.rs Co-Authored-By: Sergei Pepyakin <s.pepyakin@gmail.com> Co-authored-by: Sergei Pepyakin <s.pepyakin@gmail.com>
This commit is contained in:
@@ -408,12 +408,11 @@ HandleTrap(CONTEXT* context, bool reset_guard_page)
|
|||||||
{
|
{
|
||||||
assert(sAlreadyHandlingTrap);
|
assert(sAlreadyHandlingTrap);
|
||||||
|
|
||||||
if (!CheckIfTrapAtAddress(ContextToPC(context))) {
|
void *JmpBuf = RecordTrap(ContextToPC(context), reset_guard_page);
|
||||||
return false;
|
if (JmpBuf == nullptr) {
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
RecordTrap(ContextToPC(context), reset_guard_page);
|
|
||||||
|
|
||||||
// Unwind calls longjmp, so it doesn't run the automatic
|
// Unwind calls longjmp, so it doesn't run the automatic
|
||||||
// sAlreadyHandlingTrap cleanups, so reset it manually before doing
|
// sAlreadyHandlingTrap cleanups, so reset it manually before doing
|
||||||
// a longjmp.
|
// a longjmp.
|
||||||
@@ -423,12 +422,13 @@ HandleTrap(CONTEXT* context, bool reset_guard_page)
|
|||||||
// Reroute the PC to run the Unwind function on the main stack after the
|
// Reroute the PC to run the Unwind function on the main stack after the
|
||||||
// handler exits. This doesn't yet work for stack overflow traps, because
|
// handler exits. This doesn't yet work for stack overflow traps, because
|
||||||
// in that case the main thread doesn't have any space left to run.
|
// in that case the main thread doesn't have any space left to run.
|
||||||
SetContextPC(context, reinterpret_cast<const uint8_t*>(&Unwind));
|
assert(false); // this branch isn't implemented here
|
||||||
|
// SetContextPC(context, reinterpret_cast<const uint8_t*>(&Unwind));
|
||||||
#else
|
#else
|
||||||
// For now, just call Unwind directly, rather than redirecting the PC there,
|
// For now, just call Unwind directly, rather than redirecting the PC there,
|
||||||
// so that it runs on the alternate signal handler stack. To run on the main
|
// so that it runs on the alternate signal handler stack. To run on the main
|
||||||
// stack, reroute the context PC like this:
|
// stack, reroute the context PC like this:
|
||||||
Unwind();
|
Unwind(JmpBuf);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
|
|||||||
@@ -13,9 +13,8 @@
|
|||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
int8_t CheckIfTrapAtAddress(const uint8_t* pc);
|
|
||||||
// Record the Trap code and wasm bytecode offset in TLS somewhere
|
// Record the Trap code and wasm bytecode offset in TLS somewhere
|
||||||
void RecordTrap(const uint8_t* pc, bool reset_guard_page);
|
void* RecordTrap(const uint8_t* pc, bool reset_guard_page);
|
||||||
|
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
#include <windows.h>
|
#include <windows.h>
|
||||||
@@ -28,10 +27,7 @@ bool InstanceSignalHandler(int, siginfo_t *, void *);
|
|||||||
bool InstanceSignalHandler(int, siginfo_t *, ucontext_t *);
|
bool InstanceSignalHandler(int, siginfo_t *, ucontext_t *);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void* EnterScope(void*);
|
void Unwind(void*);
|
||||||
void LeaveScope(void*);
|
|
||||||
void* GetScope(void);
|
|
||||||
void Unwind(void);
|
|
||||||
|
|
||||||
// This function performs the low-overhead signal handler initialization that we
|
// This function performs the low-overhead signal handler initialization that we
|
||||||
// want to do eagerly to ensure a more-deterministic global process state. This
|
// want to do eagerly to ensure a more-deterministic global process state. This
|
||||||
|
|||||||
@@ -4,39 +4,34 @@
|
|||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
int WasmtimeCallTrampoline(
|
int WasmtimeCallTrampoline(
|
||||||
|
void **buf_storage,
|
||||||
void *vmctx,
|
void *vmctx,
|
||||||
void *caller_vmctx,
|
void *caller_vmctx,
|
||||||
void (*body)(void*, void*, void*),
|
void (*body)(void*, void*, void*),
|
||||||
void *args)
|
void *args)
|
||||||
{
|
{
|
||||||
jmp_buf buf;
|
jmp_buf buf;
|
||||||
void *volatile prev;
|
|
||||||
if (setjmp(buf) != 0) {
|
if (setjmp(buf) != 0) {
|
||||||
LeaveScope(prev);
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
prev = EnterScope(&buf);
|
*buf_storage = &buf;
|
||||||
body(vmctx, caller_vmctx, args);
|
body(vmctx, caller_vmctx, args);
|
||||||
LeaveScope(prev);
|
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
int WasmtimeCall(void *vmctx, void *caller_vmctx, void (*body)(void*, void*)) {
|
int WasmtimeCall(void **buf_storage, void *vmctx, void *caller_vmctx, void (*body)(void*, void*)) {
|
||||||
jmp_buf buf;
|
jmp_buf buf;
|
||||||
void *volatile prev;
|
|
||||||
if (setjmp(buf) != 0) {
|
if (setjmp(buf) != 0) {
|
||||||
LeaveScope(prev);
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
prev = EnterScope(&buf);
|
*buf_storage = &buf;
|
||||||
body(vmctx, caller_vmctx);
|
body(vmctx, caller_vmctx);
|
||||||
LeaveScope(prev);
|
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
void Unwind() {
|
void Unwind(void *JmpBuf) {
|
||||||
jmp_buf *buf = (jmp_buf*) GetScope();
|
jmp_buf *buf = (jmp_buf*) JmpBuf;
|
||||||
longjmp(*buf, 1);
|
longjmp(*buf, 1);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,89 +12,60 @@ use wasmtime_environ::ir;
|
|||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
fn WasmtimeCallTrampoline(
|
fn WasmtimeCallTrampoline(
|
||||||
|
jmp_buf: *mut *const u8,
|
||||||
vmctx: *mut u8,
|
vmctx: *mut u8,
|
||||||
caller_vmctx: *mut u8,
|
caller_vmctx: *mut u8,
|
||||||
callee: *const VMFunctionBody,
|
callee: *const VMFunctionBody,
|
||||||
values_vec: *mut u8,
|
values_vec: *mut u8,
|
||||||
) -> i32;
|
) -> i32;
|
||||||
fn WasmtimeCall(vmctx: *mut u8, caller_vmctx: *mut u8, callee: *const VMFunctionBody) -> i32;
|
fn WasmtimeCall(
|
||||||
}
|
jmp_buf: *mut *const u8,
|
||||||
|
vmctx: *mut u8,
|
||||||
thread_local! {
|
caller_vmctx: *mut u8,
|
||||||
static RECORDED_TRAP: Cell<Option<Trap>> = Cell::new(None);
|
callee: *const VMFunctionBody,
|
||||||
static JMP_BUF: Cell<*const u8> = Cell::new(ptr::null());
|
) -> i32;
|
||||||
static RESET_GUARD_PAGE: Cell<bool> = Cell::new(false);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if there is a trap at given PC
|
|
||||||
#[doc(hidden)]
|
|
||||||
#[allow(non_snake_case)]
|
|
||||||
#[no_mangle]
|
|
||||||
pub extern "C" fn CheckIfTrapAtAddress(_pc: *const u8) -> i8 {
|
|
||||||
// TODO: stack overflow can happen at any random time (i.e. in malloc() in memory.grow)
|
|
||||||
// and it's really hard to determine if the cause was stack overflow and if it happened
|
|
||||||
// in WebAssembly module.
|
|
||||||
// So, let's assume that any untrusted code called from WebAssembly doesn't trap.
|
|
||||||
// Then, if we have called some WebAssembly code, it means the trap is stack overflow.
|
|
||||||
JMP_BUF.with(|ptr| !ptr.get().is_null()) as i8
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Record the Trap code and wasm bytecode offset in TLS somewhere
|
/// Record the Trap code and wasm bytecode offset in TLS somewhere
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
#[no_mangle]
|
#[no_mangle]
|
||||||
pub extern "C" fn RecordTrap(pc: *const u8, reset_guard_page: bool) {
|
pub extern "C" fn RecordTrap(pc: *const u8, reset_guard_page: bool) -> *const u8 {
|
||||||
// TODO: please see explanation in CheckIfTrapAtAddress.
|
tls::with(|info| {
|
||||||
let registry = get_trap_registry();
|
// TODO: stack overflow can happen at any random time (i.e. in malloc()
|
||||||
let trap = Trap {
|
// in memory.grow) and it's really hard to determine if the cause was
|
||||||
desc: registry
|
// stack overflow and if it happened in WebAssembly module.
|
||||||
.get_trap(pc as usize)
|
//
|
||||||
.unwrap_or_else(|| TrapDescription {
|
// So, let's assume that any untrusted code called from WebAssembly
|
||||||
source_loc: ir::SourceLoc::default(),
|
// doesn't trap. Then, if we have called some WebAssembly code, it
|
||||||
trap_code: ir::TrapCode::StackOverflow,
|
// means the trap is stack overflow.
|
||||||
}),
|
if info.jmp_buf.get().is_null() {
|
||||||
backtrace: Backtrace::new_unresolved(),
|
return ptr::null();
|
||||||
};
|
}
|
||||||
|
|
||||||
if reset_guard_page {
|
let registry = get_trap_registry();
|
||||||
RESET_GUARD_PAGE.with(|v| v.set(true));
|
let trap = Trap {
|
||||||
}
|
desc: registry
|
||||||
|
.get_trap(pc as usize)
|
||||||
|
.unwrap_or_else(|| TrapDescription {
|
||||||
|
source_loc: ir::SourceLoc::default(),
|
||||||
|
trap_code: ir::TrapCode::StackOverflow,
|
||||||
|
}),
|
||||||
|
backtrace: Backtrace::new_unresolved(),
|
||||||
|
};
|
||||||
|
|
||||||
RECORDED_TRAP.with(|data| {
|
if reset_guard_page {
|
||||||
let prev = data.replace(Some(trap));
|
info.reset_guard_page.set(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
let prev = info.trap.replace(Some(trap));
|
||||||
assert!(
|
assert!(
|
||||||
prev.is_none(),
|
prev.is_none(),
|
||||||
"Only one trap per thread can be recorded at a moment!"
|
"Only one trap per thread can be recorded at a moment!"
|
||||||
);
|
);
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[doc(hidden)]
|
info.jmp_buf.get()
|
||||||
#[allow(non_snake_case)]
|
})
|
||||||
#[no_mangle]
|
|
||||||
pub extern "C" fn EnterScope(ptr: *const u8) -> *const u8 {
|
|
||||||
JMP_BUF.with(|buf| buf.replace(ptr))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[doc(hidden)]
|
|
||||||
#[allow(non_snake_case)]
|
|
||||||
#[no_mangle]
|
|
||||||
pub extern "C" fn GetScope() -> *const u8 {
|
|
||||||
JMP_BUF.with(|buf| buf.get())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[doc(hidden)]
|
|
||||||
#[allow(non_snake_case)]
|
|
||||||
#[no_mangle]
|
|
||||||
pub extern "C" fn LeaveScope(ptr: *const u8) {
|
|
||||||
RESET_GUARD_PAGE.with(|v| {
|
|
||||||
if v.get() {
|
|
||||||
reset_guard_page();
|
|
||||||
v.set(false);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
JMP_BUF.with(|buf| buf.set(ptr))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(target_os = "windows")]
|
#[cfg(target_os = "windows")]
|
||||||
@@ -136,12 +107,6 @@ impl fmt::Display for Trap {
|
|||||||
|
|
||||||
impl std::error::Error for Trap {}
|
impl std::error::Error for Trap {}
|
||||||
|
|
||||||
fn last_trap() -> Trap {
|
|
||||||
RECORDED_TRAP
|
|
||||||
.with(|data| data.replace(None))
|
|
||||||
.expect("trap_message must be called after trap occurred")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn trap_code_to_expected_string(trap_code: ir::TrapCode) -> String {
|
fn trap_code_to_expected_string(trap_code: ir::TrapCode) -> String {
|
||||||
use ir::TrapCode::*;
|
use ir::TrapCode::*;
|
||||||
match trap_code {
|
match trap_code {
|
||||||
@@ -170,14 +135,19 @@ pub unsafe extern "C" fn wasmtime_call_trampoline(
|
|||||||
callee: *const VMFunctionBody,
|
callee: *const VMFunctionBody,
|
||||||
values_vec: *mut u8,
|
values_vec: *mut u8,
|
||||||
) -> Result<(), Trap> {
|
) -> Result<(), Trap> {
|
||||||
if WasmtimeCallTrampoline(
|
let cx = CallThreadState::new();
|
||||||
vmctx as *mut u8,
|
let ret = tls::set(&cx, || {
|
||||||
caller_vmctx as *mut u8,
|
WasmtimeCallTrampoline(
|
||||||
callee,
|
cx.jmp_buf.as_ptr(),
|
||||||
values_vec,
|
vmctx as *mut u8,
|
||||||
) == 0
|
caller_vmctx as *mut u8,
|
||||||
{
|
callee,
|
||||||
Err(last_trap())
|
values_vec,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
if ret == 0 {
|
||||||
|
Err(cx.unwrap_trap())
|
||||||
} else {
|
} else {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -191,9 +161,91 @@ pub unsafe extern "C" fn wasmtime_call(
|
|||||||
caller_vmctx: *mut VMContext,
|
caller_vmctx: *mut VMContext,
|
||||||
callee: *const VMFunctionBody,
|
callee: *const VMFunctionBody,
|
||||||
) -> Result<(), Trap> {
|
) -> Result<(), Trap> {
|
||||||
if WasmtimeCall(vmctx as *mut u8, caller_vmctx as *mut u8, callee) == 0 {
|
let cx = CallThreadState::new();
|
||||||
Err(last_trap())
|
let ret = tls::set(&cx, || {
|
||||||
|
WasmtimeCall(
|
||||||
|
cx.jmp_buf.as_ptr(),
|
||||||
|
vmctx as *mut u8,
|
||||||
|
caller_vmctx as *mut u8,
|
||||||
|
callee,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
if ret == 0 {
|
||||||
|
Err(cx.unwrap_trap())
|
||||||
} else {
|
} else {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Temporary state stored on the stack which is registered in the `tls` module
|
||||||
|
/// below for calls into wasm.
|
||||||
|
pub struct CallThreadState {
|
||||||
|
trap: Cell<Option<Trap>>,
|
||||||
|
jmp_buf: Cell<*const u8>,
|
||||||
|
reset_guard_page: Cell<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CallThreadState {
|
||||||
|
fn new() -> CallThreadState {
|
||||||
|
CallThreadState {
|
||||||
|
trap: Cell::new(None),
|
||||||
|
jmp_buf: Cell::new(ptr::null()),
|
||||||
|
reset_guard_page: Cell::new(false),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unwrap_trap(self) -> Trap {
|
||||||
|
self.trap
|
||||||
|
.replace(None)
|
||||||
|
.expect("unwrap_trap must be called after trap occurred")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for CallThreadState {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if self.reset_guard_page.get() {
|
||||||
|
reset_guard_page();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A private inner module for managing the TLS state that we require across
|
||||||
|
// calls in wasm. The WebAssembly code is called from C++ and then a trap may
|
||||||
|
// happen which requires us to read some contextual state to figure out what to
|
||||||
|
// do with the trap. This `tls` module is used to persist that information from
|
||||||
|
// the caller to the trap site.
|
||||||
|
mod tls {
|
||||||
|
use super::CallThreadState;
|
||||||
|
use std::cell::Cell;
|
||||||
|
use std::ptr;
|
||||||
|
|
||||||
|
thread_local!(static PTR: Cell<*const CallThreadState> = Cell::new(ptr::null()));
|
||||||
|
|
||||||
|
/// Configures thread local state such that for the duration of the
|
||||||
|
/// execution of `closure` any call to `with` will yield `ptr`, unless this
|
||||||
|
/// is recursively called again.
|
||||||
|
pub fn set<R>(ptr: &CallThreadState, closure: impl FnOnce() -> R) -> R {
|
||||||
|
struct Reset<'a, T: Copy>(&'a Cell<T>, T);
|
||||||
|
|
||||||
|
impl<T: Copy> Drop for Reset<'_, T> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.0.set(self.1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
PTR.with(|p| {
|
||||||
|
let _r = Reset(p, p.replace(ptr));
|
||||||
|
closure()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the last pointer configured with `set` above. Panics if `set`
|
||||||
|
/// has not been previously called.
|
||||||
|
pub fn with<R>(closure: impl FnOnce(&CallThreadState) -> R) -> R {
|
||||||
|
PTR.with(|ptr| {
|
||||||
|
let p = ptr.get();
|
||||||
|
assert!(!p.is_null());
|
||||||
|
unsafe { closure(&*p) }
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user