diff --git a/crates/runtime/signalhandlers/SignalHandlers.cpp b/crates/runtime/signalhandlers/SignalHandlers.cpp index 40fb8c1f5f..63d46c7689 100644 --- a/crates/runtime/signalhandlers/SignalHandlers.cpp +++ b/crates/runtime/signalhandlers/SignalHandlers.cpp @@ -408,12 +408,11 @@ HandleTrap(CONTEXT* context, bool reset_guard_page) { assert(sAlreadyHandlingTrap); - if (!CheckIfTrapAtAddress(ContextToPC(context))) { - return false; + void *JmpBuf = RecordTrap(ContextToPC(context), reset_guard_page); + if (JmpBuf == nullptr) { + return false; } - RecordTrap(ContextToPC(context), reset_guard_page); - // Unwind calls longjmp, so it doesn't run the automatic // sAlreadhHanldingTrap cleanups, so reset it manually before doing // a longjmp. @@ -423,12 +422,13 @@ HandleTrap(CONTEXT* context, bool reset_guard_page) // Reroute the PC to run the Unwind function on the main stack after the // handler exits. This doesn't yet work for stack overflow traps, because // in that case the main thread doesn't have any space left to run. - SetContextPC(context, reinterpret_cast(&Unwind)); + assert(false); // this branch isn't implemented here + // SetContextPC(context, reinterpret_cast(&Unwind)); #else // For now, just call Unwind directly, rather than redirecting the PC there, // so that it runs on the alternate signal handler stack. To run on the main // stack, reroute the context PC like this: - Unwind(); + Unwind(JmpBuf); #endif return true; diff --git a/crates/runtime/signalhandlers/SignalHandlers.hpp b/crates/runtime/signalhandlers/SignalHandlers.hpp index 5bccd03e5d..623f25f269 100644 --- a/crates/runtime/signalhandlers/SignalHandlers.hpp +++ b/crates/runtime/signalhandlers/SignalHandlers.hpp @@ -13,9 +13,8 @@ extern "C" { #endif -int8_t CheckIfTrapAtAddress(const uint8_t* pc); // Record the Trap code and wasm bytecode offset in TLS somewhere -void RecordTrap(const uint8_t* pc, bool reset_guard_page); +void* RecordTrap(const uint8_t* pc, bool reset_guard_page); #if defined(_WIN32) #include @@ -28,10 +27,7 @@ bool InstanceSignalHandler(int, siginfo_t *, void *); bool InstanceSignalHandler(int, siginfo_t *, ucontext_t *); #endif -void* EnterScope(void*); -void LeaveScope(void*); -void* GetScope(void); -void Unwind(void); +void Unwind(void*); // This function performs the low-overhead signal handler initialization that we // want to do eagerly to ensure a more-deterministic global process state. This diff --git a/crates/runtime/signalhandlers/Trampolines.cpp b/crates/runtime/signalhandlers/Trampolines.cpp index d74f094280..e554823652 100644 --- a/crates/runtime/signalhandlers/Trampolines.cpp +++ b/crates/runtime/signalhandlers/Trampolines.cpp @@ -4,39 +4,34 @@ extern "C" int WasmtimeCallTrampoline( + void **buf_storage, void *vmctx, void *caller_vmctx, void (*body)(void*, void*, void*), void *args) { jmp_buf buf; - void *volatile prev; if (setjmp(buf) != 0) { - LeaveScope(prev); return 0; } - prev = EnterScope(&buf); + *buf_storage = &buf; body(vmctx, caller_vmctx, args); - LeaveScope(prev); return 1; } extern "C" -int WasmtimeCall(void *vmctx, void *caller_vmctx, void (*body)(void*, void*)) { +int WasmtimeCall(void **buf_storage, void *vmctx, void *caller_vmctx, void (*body)(void*, void*)) { jmp_buf buf; - void *volatile prev; if (setjmp(buf) != 0) { - LeaveScope(prev); return 0; } - prev = EnterScope(&buf); + *buf_storage = &buf; body(vmctx, caller_vmctx); - LeaveScope(prev); return 1; } extern "C" -void Unwind() { - jmp_buf *buf = (jmp_buf*) GetScope(); +void Unwind(void *JmpBuf) { + jmp_buf *buf = (jmp_buf*) JmpBuf; longjmp(*buf, 1); } diff --git a/crates/runtime/src/traphandlers.rs b/crates/runtime/src/traphandlers.rs index afc15282fb..bc05e1cbf1 100644 --- a/crates/runtime/src/traphandlers.rs +++ b/crates/runtime/src/traphandlers.rs @@ -12,89 +12,60 @@ use wasmtime_environ::ir; extern "C" { fn WasmtimeCallTrampoline( + jmp_buf: *mut *const u8, vmctx: *mut u8, caller_vmctx: *mut u8, callee: *const VMFunctionBody, values_vec: *mut u8, ) -> i32; - fn WasmtimeCall(vmctx: *mut u8, caller_vmctx: *mut u8, callee: *const VMFunctionBody) -> i32; -} - -thread_local! { - static RECORDED_TRAP: Cell> = Cell::new(None); - static JMP_BUF: Cell<*const u8> = Cell::new(ptr::null()); - static RESET_GUARD_PAGE: Cell = Cell::new(false); -} - -/// Check if there is a trap at given PC -#[doc(hidden)] -#[allow(non_snake_case)] -#[no_mangle] -pub extern "C" fn CheckIfTrapAtAddress(_pc: *const u8) -> i8 { - // TODO: stack overflow can happen at any random time (i.e. in malloc() in memory.grow) - // and it's really hard to determine if the cause was stack overflow and if it happened - // in WebAssembly module. - // So, let's assume that any untrusted code called from WebAssembly doesn't trap. - // Then, if we have called some WebAssembly code, it means the trap is stack overflow. - JMP_BUF.with(|ptr| !ptr.get().is_null()) as i8 + fn WasmtimeCall( + jmp_buf: *mut *const u8, + vmctx: *mut u8, + caller_vmctx: *mut u8, + callee: *const VMFunctionBody, + ) -> i32; } /// Record the Trap code and wasm bytecode offset in TLS somewhere #[doc(hidden)] #[allow(non_snake_case)] #[no_mangle] -pub extern "C" fn RecordTrap(pc: *const u8, reset_guard_page: bool) { - // TODO: please see explanation in CheckIfTrapAtAddress. - let registry = get_trap_registry(); - let trap = Trap { - desc: registry - .get_trap(pc as usize) - .unwrap_or_else(|| TrapDescription { - source_loc: ir::SourceLoc::default(), - trap_code: ir::TrapCode::StackOverflow, - }), - backtrace: Backtrace::new_unresolved(), - }; +pub extern "C" fn RecordTrap(pc: *const u8, reset_guard_page: bool) -> *const u8 { + tls::with(|info| { + // TODO: stack overflow can happen at any random time (i.e. in malloc() + // in memory.grow) and it's really hard to determine if the cause was + // stack overflow and if it happened in WebAssembly module. + // + // So, let's assume that any untrusted code called from WebAssembly + // doesn't trap. Then, if we have called some WebAssembly code, it + // means the trap is stack overflow. + if info.jmp_buf.get().is_null() { + return ptr::null(); + } - if reset_guard_page { - RESET_GUARD_PAGE.with(|v| v.set(true)); - } + let registry = get_trap_registry(); + let trap = Trap { + desc: registry + .get_trap(pc as usize) + .unwrap_or_else(|| TrapDescription { + source_loc: ir::SourceLoc::default(), + trap_code: ir::TrapCode::StackOverflow, + }), + backtrace: Backtrace::new_unresolved(), + }; - RECORDED_TRAP.with(|data| { - let prev = data.replace(Some(trap)); + if reset_guard_page { + info.reset_guard_page.set(true); + } + + let prev = info.trap.replace(Some(trap)); assert!( prev.is_none(), "Only one trap per thread can be recorded at a moment!" ); - }); -} -#[doc(hidden)] -#[allow(non_snake_case)] -#[no_mangle] -pub extern "C" fn EnterScope(ptr: *const u8) -> *const u8 { - JMP_BUF.with(|buf| buf.replace(ptr)) -} - -#[doc(hidden)] -#[allow(non_snake_case)] -#[no_mangle] -pub extern "C" fn GetScope() -> *const u8 { - JMP_BUF.with(|buf| buf.get()) -} - -#[doc(hidden)] -#[allow(non_snake_case)] -#[no_mangle] -pub extern "C" fn LeaveScope(ptr: *const u8) { - RESET_GUARD_PAGE.with(|v| { - if v.get() { - reset_guard_page(); - v.set(false); - } - }); - - JMP_BUF.with(|buf| buf.set(ptr)) + info.jmp_buf.get() + }) } #[cfg(target_os = "windows")] @@ -136,12 +107,6 @@ impl fmt::Display for Trap { impl std::error::Error for Trap {} -fn last_trap() -> Trap { - RECORDED_TRAP - .with(|data| data.replace(None)) - .expect("trap_message must be called after trap occurred") -} - fn trap_code_to_expected_string(trap_code: ir::TrapCode) -> String { use ir::TrapCode::*; match trap_code { @@ -170,14 +135,19 @@ pub unsafe extern "C" fn wasmtime_call_trampoline( callee: *const VMFunctionBody, values_vec: *mut u8, ) -> Result<(), Trap> { - if WasmtimeCallTrampoline( - vmctx as *mut u8, - caller_vmctx as *mut u8, - callee, - values_vec, - ) == 0 - { - Err(last_trap()) + let cx = CallThreadState::new(); + let ret = tls::set(&cx, || { + WasmtimeCallTrampoline( + cx.jmp_buf.as_ptr(), + vmctx as *mut u8, + caller_vmctx as *mut u8, + callee, + values_vec, + ) + }); + + if ret == 0 { + Err(cx.unwrap_trap()) } else { Ok(()) } @@ -191,9 +161,91 @@ pub unsafe extern "C" fn wasmtime_call( caller_vmctx: *mut VMContext, callee: *const VMFunctionBody, ) -> Result<(), Trap> { - if WasmtimeCall(vmctx as *mut u8, caller_vmctx as *mut u8, callee) == 0 { - Err(last_trap()) + let cx = CallThreadState::new(); + let ret = tls::set(&cx, || { + WasmtimeCall( + cx.jmp_buf.as_ptr(), + vmctx as *mut u8, + caller_vmctx as *mut u8, + callee, + ) + }); + if ret == 0 { + Err(cx.unwrap_trap()) } else { Ok(()) } } + +/// Temporary state stored on the stack which is registered in the `tls` module +/// below for calls into wasm. +pub struct CallThreadState { + trap: Cell>, + jmp_buf: Cell<*const u8>, + reset_guard_page: Cell, +} + +impl CallThreadState { + fn new() -> CallThreadState { + CallThreadState { + trap: Cell::new(None), + jmp_buf: Cell::new(ptr::null()), + reset_guard_page: Cell::new(false), + } + } + + fn unwrap_trap(self) -> Trap { + self.trap + .replace(None) + .expect("unwrap_trap must be called after trap occurred") + } +} + +impl Drop for CallThreadState { + fn drop(&mut self) { + if self.reset_guard_page.get() { + reset_guard_page(); + } + } +} + +// A private inner module for managing the TLS state that we require across +// calls in wasm. The WebAssembly code is called from C++ and then a trap may +// happen which requires us to read some contextual state to figure out what to +// do with the trap. This `tls` module is used to persist that information from +// the caller to the trap site. +mod tls { + use super::CallThreadState; + use std::cell::Cell; + use std::ptr; + + thread_local!(static PTR: Cell<*const CallThreadState> = Cell::new(ptr::null())); + + /// Configures thread local state such that for the duration of the + /// execution of `closure` any call to `with` will yield `ptr`, unless this + /// is recursively called again. + pub fn set(ptr: &CallThreadState, closure: impl FnOnce() -> R) -> R { + struct Reset<'a, T: Copy>(&'a Cell, T); + + impl Drop for Reset<'_, T> { + fn drop(&mut self) { + self.0.set(self.1); + } + } + + PTR.with(|p| { + let _r = Reset(p, p.replace(ptr)); + closure() + }) + } + + /// Returns the last pointer configured with `set` above. Panics if `set` + /// has not been previously called. + pub fn with(closure: impl FnOnce(&CallThreadState) -> R) -> R { + PTR.with(|ptr| { + let p = ptr.get(); + assert!(!p.is_null()); + unsafe { closure(&*p) } + }) + } +}