From 05d6c27142cfdd31bc6b837114549974d93204cb Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Thu, 23 Jan 2020 14:34:47 -0600 Subject: [PATCH] Reduce number of thread locals in trap handling (#852) * Reduce number of thread locals in trap handling This commit refactors the trap handling portion of wasmtime with a few goals in mind. I've been reading around a bit lately and feel that we have a bit too many globals and thread locals floating around rather than handles attached to contexts. I'm hoping that we can reduce the number of thread locals and globals, and this commit is the start of reducing this number. The changes applied in this commit remove the set of thread locals in the `traphandlers` module in favor of one thread local that's managed in a sort of stack discipline. This way each call to `wasmtime_call*` sets up its own stack local state that can be managed and read on that stack frame. Additionally the C++ glue code around `setjmp` and `longjmp` has all been refactored to avoid going back and forth between Rust and C++. Now we'll simply enter C++, go straight into `setjmp`/the call, and then traps will enter Rust only once to both learn if the trap should be acted upon and record information about the trap. Overall the hope here is that context passing between `wasmtime_call*` and the trap handling function will be a bit easier. For example I hope to remove the global `get_trap_registry()` function next in favor of storing a handle to a registry inside each instance, and the `*mut VMContext` can be used to reach the `InstanceHandle` underneath, and this trap registry. 
* Update crates/runtime/src/traphandlers.rs Co-Authored-By: Sergei Pepyakin Co-authored-by: Sergei Pepyakin --- .../runtime/signalhandlers/SignalHandlers.cpp | 12 +- .../runtime/signalhandlers/SignalHandlers.hpp | 8 +- crates/runtime/signalhandlers/Trampolines.cpp | 17 +- crates/runtime/src/traphandlers.rs | 214 +++++++++++------- 4 files changed, 147 insertions(+), 104 deletions(-) diff --git a/crates/runtime/signalhandlers/SignalHandlers.cpp b/crates/runtime/signalhandlers/SignalHandlers.cpp index 40fb8c1f5f..63d46c7689 100644 --- a/crates/runtime/signalhandlers/SignalHandlers.cpp +++ b/crates/runtime/signalhandlers/SignalHandlers.cpp @@ -408,12 +408,11 @@ HandleTrap(CONTEXT* context, bool reset_guard_page) { assert(sAlreadyHandlingTrap); - if (!CheckIfTrapAtAddress(ContextToPC(context))) { - return false; + void *JmpBuf = RecordTrap(ContextToPC(context), reset_guard_page); + if (JmpBuf == nullptr) { + return false; } - RecordTrap(ContextToPC(context), reset_guard_page); - // Unwind calls longjmp, so it doesn't run the automatic // sAlreadhHanldingTrap cleanups, so reset it manually before doing // a longjmp. @@ -423,12 +422,13 @@ HandleTrap(CONTEXT* context, bool reset_guard_page) // Reroute the PC to run the Unwind function on the main stack after the // handler exits. This doesn't yet work for stack overflow traps, because // in that case the main thread doesn't have any space left to run. - SetContextPC(context, reinterpret_cast(&Unwind)); + assert(false); // this branch isn't implemented here + // SetContextPC(context, reinterpret_cast(&Unwind)); #else // For now, just call Unwind directly, rather than redirecting the PC there, // so that it runs on the alternate signal handler stack. 
To run on the main // stack, reroute the context PC like this: - Unwind(); + Unwind(JmpBuf); #endif return true; diff --git a/crates/runtime/signalhandlers/SignalHandlers.hpp b/crates/runtime/signalhandlers/SignalHandlers.hpp index 5bccd03e5d..623f25f269 100644 --- a/crates/runtime/signalhandlers/SignalHandlers.hpp +++ b/crates/runtime/signalhandlers/SignalHandlers.hpp @@ -13,9 +13,8 @@ extern "C" { #endif -int8_t CheckIfTrapAtAddress(const uint8_t* pc); // Record the Trap code and wasm bytecode offset in TLS somewhere -void RecordTrap(const uint8_t* pc, bool reset_guard_page); +void* RecordTrap(const uint8_t* pc, bool reset_guard_page); #if defined(_WIN32) #include @@ -28,10 +27,7 @@ bool InstanceSignalHandler(int, siginfo_t *, void *); bool InstanceSignalHandler(int, siginfo_t *, ucontext_t *); #endif -void* EnterScope(void*); -void LeaveScope(void*); -void* GetScope(void); -void Unwind(void); +void Unwind(void*); // This function performs the low-overhead signal handler initialization that we // want to do eagerly to ensure a more-deterministic global process state. 
This diff --git a/crates/runtime/signalhandlers/Trampolines.cpp b/crates/runtime/signalhandlers/Trampolines.cpp index d74f094280..e554823652 100644 --- a/crates/runtime/signalhandlers/Trampolines.cpp +++ b/crates/runtime/signalhandlers/Trampolines.cpp @@ -4,39 +4,34 @@ extern "C" int WasmtimeCallTrampoline( + void **buf_storage, void *vmctx, void *caller_vmctx, void (*body)(void*, void*, void*), void *args) { jmp_buf buf; - void *volatile prev; if (setjmp(buf) != 0) { - LeaveScope(prev); return 0; } - prev = EnterScope(&buf); + *buf_storage = &buf; body(vmctx, caller_vmctx, args); - LeaveScope(prev); return 1; } extern "C" -int WasmtimeCall(void *vmctx, void *caller_vmctx, void (*body)(void*, void*)) { +int WasmtimeCall(void **buf_storage, void *vmctx, void *caller_vmctx, void (*body)(void*, void*)) { jmp_buf buf; - void *volatile prev; if (setjmp(buf) != 0) { - LeaveScope(prev); return 0; } - prev = EnterScope(&buf); + *buf_storage = &buf; body(vmctx, caller_vmctx); - LeaveScope(prev); return 1; } extern "C" -void Unwind() { - jmp_buf *buf = (jmp_buf*) GetScope(); +void Unwind(void *JmpBuf) { + jmp_buf *buf = (jmp_buf*) JmpBuf; longjmp(*buf, 1); } diff --git a/crates/runtime/src/traphandlers.rs b/crates/runtime/src/traphandlers.rs index afc15282fb..bc05e1cbf1 100644 --- a/crates/runtime/src/traphandlers.rs +++ b/crates/runtime/src/traphandlers.rs @@ -12,89 +12,60 @@ use wasmtime_environ::ir; extern "C" { fn WasmtimeCallTrampoline( + jmp_buf: *mut *const u8, vmctx: *mut u8, caller_vmctx: *mut u8, callee: *const VMFunctionBody, values_vec: *mut u8, ) -> i32; - fn WasmtimeCall(vmctx: *mut u8, caller_vmctx: *mut u8, callee: *const VMFunctionBody) -> i32; -} - -thread_local! 
{ - static RECORDED_TRAP: Cell> = Cell::new(None); - static JMP_BUF: Cell<*const u8> = Cell::new(ptr::null()); - static RESET_GUARD_PAGE: Cell = Cell::new(false); -} - -/// Check if there is a trap at given PC -#[doc(hidden)] -#[allow(non_snake_case)] -#[no_mangle] -pub extern "C" fn CheckIfTrapAtAddress(_pc: *const u8) -> i8 { - // TODO: stack overflow can happen at any random time (i.e. in malloc() in memory.grow) - // and it's really hard to determine if the cause was stack overflow and if it happened - // in WebAssembly module. - // So, let's assume that any untrusted code called from WebAssembly doesn't trap. - // Then, if we have called some WebAssembly code, it means the trap is stack overflow. - JMP_BUF.with(|ptr| !ptr.get().is_null()) as i8 + fn WasmtimeCall( + jmp_buf: *mut *const u8, + vmctx: *mut u8, + caller_vmctx: *mut u8, + callee: *const VMFunctionBody, + ) -> i32; } /// Record the Trap code and wasm bytecode offset in TLS somewhere #[doc(hidden)] #[allow(non_snake_case)] #[no_mangle] -pub extern "C" fn RecordTrap(pc: *const u8, reset_guard_page: bool) { - // TODO: please see explanation in CheckIfTrapAtAddress. - let registry = get_trap_registry(); - let trap = Trap { - desc: registry - .get_trap(pc as usize) - .unwrap_or_else(|| TrapDescription { - source_loc: ir::SourceLoc::default(), - trap_code: ir::TrapCode::StackOverflow, - }), - backtrace: Backtrace::new_unresolved(), - }; +pub extern "C" fn RecordTrap(pc: *const u8, reset_guard_page: bool) -> *const u8 { + tls::with(|info| { + // TODO: stack overflow can happen at any random time (i.e. in malloc() + // in memory.grow) and it's really hard to determine if the cause was + // stack overflow and if it happened in WebAssembly module. + // + // So, let's assume that any untrusted code called from WebAssembly + // doesn't trap. Then, if we have called some WebAssembly code, it + // means the trap is stack overflow. 
+ if info.jmp_buf.get().is_null() { + return ptr::null(); + } - if reset_guard_page { - RESET_GUARD_PAGE.with(|v| v.set(true)); - } + let registry = get_trap_registry(); + let trap = Trap { + desc: registry + .get_trap(pc as usize) + .unwrap_or_else(|| TrapDescription { + source_loc: ir::SourceLoc::default(), + trap_code: ir::TrapCode::StackOverflow, + }), + backtrace: Backtrace::new_unresolved(), + }; - RECORDED_TRAP.with(|data| { - let prev = data.replace(Some(trap)); + if reset_guard_page { + info.reset_guard_page.set(true); + } + + let prev = info.trap.replace(Some(trap)); assert!( prev.is_none(), "Only one trap per thread can be recorded at a moment!" ); - }); -} -#[doc(hidden)] -#[allow(non_snake_case)] -#[no_mangle] -pub extern "C" fn EnterScope(ptr: *const u8) -> *const u8 { - JMP_BUF.with(|buf| buf.replace(ptr)) -} - -#[doc(hidden)] -#[allow(non_snake_case)] -#[no_mangle] -pub extern "C" fn GetScope() -> *const u8 { - JMP_BUF.with(|buf| buf.get()) -} - -#[doc(hidden)] -#[allow(non_snake_case)] -#[no_mangle] -pub extern "C" fn LeaveScope(ptr: *const u8) { - RESET_GUARD_PAGE.with(|v| { - if v.get() { - reset_guard_page(); - v.set(false); - } - }); - - JMP_BUF.with(|buf| buf.set(ptr)) + info.jmp_buf.get() + }) } #[cfg(target_os = "windows")] @@ -136,12 +107,6 @@ impl fmt::Display for Trap { impl std::error::Error for Trap {} -fn last_trap() -> Trap { - RECORDED_TRAP - .with(|data| data.replace(None)) - .expect("trap_message must be called after trap occurred") -} - fn trap_code_to_expected_string(trap_code: ir::TrapCode) -> String { use ir::TrapCode::*; match trap_code { @@ -170,14 +135,19 @@ pub unsafe extern "C" fn wasmtime_call_trampoline( callee: *const VMFunctionBody, values_vec: *mut u8, ) -> Result<(), Trap> { - if WasmtimeCallTrampoline( - vmctx as *mut u8, - caller_vmctx as *mut u8, - callee, - values_vec, - ) == 0 - { - Err(last_trap()) + let cx = CallThreadState::new(); + let ret = tls::set(&cx, || { + WasmtimeCallTrampoline( + cx.jmp_buf.as_ptr(), 
+ vmctx as *mut u8, + caller_vmctx as *mut u8, + callee, + values_vec, + ) + }); + + if ret == 0 { + Err(cx.unwrap_trap()) } else { Ok(()) } @@ -191,9 +161,91 @@ pub unsafe extern "C" fn wasmtime_call( caller_vmctx: *mut VMContext, callee: *const VMFunctionBody, ) -> Result<(), Trap> { - if WasmtimeCall(vmctx as *mut u8, caller_vmctx as *mut u8, callee) == 0 { - Err(last_trap()) + let cx = CallThreadState::new(); + let ret = tls::set(&cx, || { + WasmtimeCall( + cx.jmp_buf.as_ptr(), + vmctx as *mut u8, + caller_vmctx as *mut u8, + callee, + ) + }); + if ret == 0 { + Err(cx.unwrap_trap()) } else { Ok(()) } } + +/// Temporary state stored on the stack which is registered in the `tls` module +/// below for calls into wasm. +pub struct CallThreadState { + trap: Cell>, + jmp_buf: Cell<*const u8>, + reset_guard_page: Cell, +} + +impl CallThreadState { + fn new() -> CallThreadState { + CallThreadState { + trap: Cell::new(None), + jmp_buf: Cell::new(ptr::null()), + reset_guard_page: Cell::new(false), + } + } + + fn unwrap_trap(self) -> Trap { + self.trap + .replace(None) + .expect("unwrap_trap must be called after trap occurred") + } +} + +impl Drop for CallThreadState { + fn drop(&mut self) { + if self.reset_guard_page.get() { + reset_guard_page(); + } + } +} + +// A private inner module for managing the TLS state that we require across +// calls in wasm. The WebAssembly code is called from C++ and then a trap may +// happen which requires us to read some contextual state to figure out what to +// do with the trap. This `tls` module is used to persist that information from +// the caller to the trap site. +mod tls { + use super::CallThreadState; + use std::cell::Cell; + use std::ptr; + + thread_local!(static PTR: Cell<*const CallThreadState> = Cell::new(ptr::null())); + + /// Configures thread local state such that for the duration of the + /// execution of `closure` any call to `with` will yield `ptr`, unless this + /// is recursively called again. 
+ pub fn set(ptr: &CallThreadState, closure: impl FnOnce() -> R) -> R { + struct Reset<'a, T: Copy>(&'a Cell, T); + + impl Drop for Reset<'_, T> { + fn drop(&mut self) { + self.0.set(self.1); + } + } + + PTR.with(|p| { + let _r = Reset(p, p.replace(ptr)); + closure() + }) + } + + /// Returns the last pointer configured with `set` above. Panics if `set` + /// has not been previously called. + pub fn with(closure: impl FnOnce(&CallThreadState) -> R) -> R { + PTR.with(|ptr| { + let p = ptr.get(); + assert!(!p.is_null()); + unsafe { closure(&*p) } + }) + } +}