From d719ec7e1cea5e3751d065c3ab9fa5813dffeee2 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Wed, 29 Apr 2020 14:24:54 -0500 Subject: [PATCH] Don't try to handle non-wasmtime segfaults (#1577) This commit fixes an issue in Wasmtime where Wasmtime would accidentally "handle" non-wasm segfaults while executing host imports of wasm modules. If a host import segfaulted then Wasmtime would recognize that wasm code is on the stack, so it'd longjmp out of the wasm code. This papers over real bugs though in host code and erroneously classified segfaults as wasm traps. The fix here was to add a check to our wasm signal handler for if the faulting address falls in JIT code itself. Actually threading through all the right information for that check to happen is a bit tricky, though, so this involved some refactoring: * A closure parameter to `catch_traps` was added. This closure is responsible for classifying addresses as whether or not they fall in JIT code. Anything returning `false` means that the trap won't get handled and we'll forward to the next signal handler. * To avoid passing tons of context all over the place, the start function is now no longer automatically invoked by `InstanceHandle`. This avoids the need for passing all sorts of trap-handling contextual information like the maximum stack size and "is this a jit address" closure. Instead creators of `InstanceHandle` (like wasmtime) are now responsible for invoking the start function. * To avoid excessive use of `transmute` with lifetimes since the traphandler state now has a lifetime the per-instance custom signal handler is now replaced with a per-store custom signal handler. I'm not entirely certain the purpose of the custom signal handler, though, so I'd look for feedback on this part. A new test has been added which ensures that if a host function segfaults we don't accidentally try to handle it, and instead we correctly report the segfault. --- crates/api/src/func.rs | 47 +++++---- crates/api/src/instance.rs | 22 ++++- crates/api/src/runtime.rs | 16 +++- crates/api/src/unix.rs | 10 +- crates/api/src/windows.rs | 10 +- crates/jit/src/code_memory.rs | 14 +++ crates/jit/src/compiler.rs | 6 ++ crates/runtime/src/instance.rs | 91 +----------------- crates/runtime/src/lib.rs | 5 +- crates/runtime/src/traphandlers.rs | 149 ++++++++++++++++------------- tests/all/custom_signal_handler.rs | 60 ++++++------ tests/host_segfault.rs | 6 ++ 12 files changed, 214 insertions(+), 222 deletions(-) diff --git a/crates/api/src/func.rs b/crates/api/src/func.rs index ac1838a5e4..b1b18be7c1 100644 --- a/crates/api/src/func.rs +++ b/crates/api/src/func.rs @@ -178,7 +178,6 @@ macro_rules! getters { // of the closure. Pass the export in so that we can call it. let instance = self.instance.clone(); let export = self.export.clone(); - let max_wasm_stack = self.store().engine().config().max_wasm_stack; // ... and then once we've passed the typechecks we can hand out our // object since our `transmute` below should be safe! @@ -194,13 +193,9 @@ macro_rules! getters { >(export.address); let mut ret = None; $(let $args = $args.into_abi();)* - wasmtime_runtime::catch_traps(export.vmctx, max_wasm_stack, || { + catch_traps(export.vmctx, &instance.store, || { ret = Some(fnptr(export.vmctx, ptr::null_mut(), $($args,)*)); - }).map_err(Trap::from_jit)?; - - // We're holding this handle just to ensure that the instance stays - // live while we call into it. - drop(&instance); + })?; Ok(ret.unwrap()) } @@ -562,22 +557,14 @@ impl Func { } // Call the trampoline. - if let Err(error) = unsafe { - wasmtime_runtime::catch_traps( + catch_traps(self.export.vmctx, &self.instance.store, || unsafe { + (self.trampoline)( self.export.vmctx, - self.instance.store.engine().config().max_wasm_stack, - || { - (self.trampoline)( - self.export.vmctx, - ptr::null_mut(), - self.export.address, - values_vec.as_mut_ptr(), - ) - }, + ptr::null_mut(), + self.export.address, + values_vec.as_mut_ptr(), ) - } { - return Err(Trap::from_jit(error).into()); - } + })?; // Load the return values out of `values_vec`. let mut results = Vec::with_capacity(my_ty.results().len()); @@ -746,6 +733,24 @@ impl fmt::Debug for Func { } } +pub(crate) fn catch_traps( + vmctx: *mut VMContext, + store: &Store, + closure: impl FnMut(), +) -> Result<(), Trap> { + let signalhandler = store.signal_handler(); + unsafe { + wasmtime_runtime::catch_traps( + vmctx, + store.engine().config().max_wasm_stack, + |addr| store.compiler().is_in_jit_code(addr), + signalhandler.as_deref(), + closure, + ) + .map_err(Trap::from_jit) + } +} + /// A trait implemented for types which can be arguments to closures passed to /// [`Func::wrap`] and friends. /// diff --git a/crates/api/src/instance.rs b/crates/api/src/instance.rs index 02c535e773..808f5d3979 100644 --- a/crates/api/src/instance.rs +++ b/crates/api/src/instance.rs @@ -2,8 +2,10 @@ use crate::trampoline::StoreInstanceHandle; use crate::{Export, Extern, Func, Global, Memory, Module, Store, Table, Trap}; use anyhow::{bail, Error, Result}; use std::any::Any; +use std::mem; +use wasmtime_environ::EntityIndex; use wasmtime_jit::{CompiledModule, Resolver}; -use wasmtime_runtime::{InstantiationError, SignatureRegistry}; +use wasmtime_runtime::{InstantiationError, SignatureRegistry, VMContext, VMFunctionBody}; struct SimpleResolver<'a> { imports: &'a [Extern], @@ -45,7 +47,6 @@ fn instantiate( instance .initialize( config.validating_config.operator_config.enable_bulk_memory, - config.max_wasm_stack, &compiled_module.data_initializers(), ) .map_err(|e| -> Error { @@ -56,6 +57,23 @@ fn instantiate( other => other.into(), } })?; + + // If a start function is present, now that we've got our compiled + // instance we can invoke it. Make sure we use all the trap-handling + // configuration in `store` as well. + if let Some(start) = instance.module().start_func { + let f = match instance.lookup_by_declaration(&EntityIndex::Function(start)) { + wasmtime_runtime::Export::Function(f) => f, + _ => unreachable!(), // valid modules shouldn't hit this + }; + super::func::catch_traps(instance.vmctx_ptr(), store, || { + mem::transmute::< + *const VMFunctionBody, + unsafe extern "C" fn(*mut VMContext, *mut VMContext), + >(f.address)(f.vmctx, instance.vmctx_ptr()) + })?; + } + Ok(instance) } } diff --git a/crates/api/src/runtime.rs b/crates/api/src/runtime.rs index 6592cc93df..d81b11a064 100644 --- a/crates/api/src/runtime.rs +++ b/crates/api/src/runtime.rs @@ -12,7 +12,9 @@ use wasmtime_environ::settings::{self, Configurable}; use wasmtime_environ::{CacheConfig, Tunables}; use wasmtime_jit::{native, CompilationStrategy, Compiler}; use wasmtime_profiling::{JitDumpAgent, NullProfilerAgent, ProfilingAgent, VTuneAgent}; -use wasmtime_runtime::{debug_builtins, InstanceHandle, RuntimeMemoryCreator, VMInterrupts}; +use wasmtime_runtime::{ + debug_builtins, InstanceHandle, RuntimeMemoryCreator, SignalHandler, VMInterrupts, +}; // Runtime Environment @@ -557,6 +559,7 @@ pub(crate) struct StoreInner { engine: Engine, compiler: RefCell, instances: RefCell>, + signal_handler: RefCell>>>, } impl Store { @@ -574,6 +577,7 @@ impl Store { engine: engine.clone(), compiler: RefCell::new(compiler), instances: RefCell::new(Vec::new()), + signal_handler: RefCell::new(None), }), } } @@ -625,6 +629,16 @@ impl Store { Rc::downgrade(&self.inner) } + pub(crate) fn signal_handler(&self) -> std::cell::Ref<'_, Option>>> { + self.inner.signal_handler.borrow() + } + + pub(crate) fn signal_handler_mut( + &self, + ) -> std::cell::RefMut<'_, Option>>> { + self.inner.signal_handler.borrow_mut() + } + /// Returns whether the stores `a` and `b` refer to the same underlying /// `Store`. /// diff --git a/crates/api/src/unix.rs b/crates/api/src/unix.rs index 22fa392597..4e2f8c61c5 100644 --- a/crates/api/src/unix.rs +++ b/crates/api/src/unix.rs @@ -9,10 +9,10 @@ //! throughout the `wasmtime` crate with extra functionality that's only //! available on Unix. -use crate::Instance; +use crate::Store; -/// Extensions for the [`Instance`] type only available on Unix. -pub trait InstanceExt { +/// Extensions for the [`Store`] type only available on Unix. +pub trait StoreExt { // TODO: needs more docs? /// The signal handler must be /// [async-signal-safe](http://man7.org/linux/man-pages/man7/signal-safety.7.html). @@ -21,11 +21,11 @@ pub trait InstanceExt { H: 'static + Fn(libc::c_int, *const libc::siginfo_t, *const libc::c_void) -> bool; } -impl InstanceExt for Instance { +impl StoreExt for Store { unsafe fn set_signal_handler(&self, handler: H) where H: 'static + Fn(libc::c_int, *const libc::siginfo_t, *const libc::c_void) -> bool, { - self.handle.set_signal_handler(handler); + *self.signal_handler_mut() = Some(Box::new(handler)); } } diff --git a/crates/api/src/windows.rs b/crates/api/src/windows.rs index 5725007d53..f691e1aa35 100644 --- a/crates/api/src/windows.rs +++ b/crates/api/src/windows.rs @@ -9,10 +9,10 @@ //! throughout the `wasmtime` crate with extra functionality that's only //! available on Windows. -use crate::Instance; +use crate::Store; -/// Extensions for the [`Instance`] type only available on Windows. -pub trait InstanceExt { +/// Extensions for the [`Store`] type only available on Windows. +pub trait StoreExt { /// Configures a custom signal handler to execute. /// /// TODO: needs more documentation. @@ -21,11 +21,11 @@ pub trait InstanceExt { H: 'static + Fn(winapi::um::winnt::PEXCEPTION_POINTERS) -> bool; } -impl InstanceExt for Instance { +impl StoreExt for Store { unsafe fn set_signal_handler(&self, handler: H) where H: 'static + Fn(winapi::um::winnt::PEXCEPTION_POINTERS) -> bool, { - self.handle.set_signal_handler(handler); + *self.signal_handler_mut() = Some(Box::new(handler)); } } diff --git a/crates/jit/src/code_memory.rs b/crates/jit/src/code_memory.rs index 6d90c85710..e0b7315c4d 100644 --- a/crates/jit/src/code_memory.rs +++ b/crates/jit/src/code_memory.rs @@ -21,6 +21,12 @@ impl CodeMemoryEntry { let registry = ManuallyDrop::new(UnwindRegistry::new(mmap.as_ptr() as usize)); Ok(Self { mmap, registry }) } + + fn contains(&self, addr: usize) -> bool { + let start = self.mmap.as_ptr() as usize; + let end = start + self.mmap.len(); + start <= addr && addr < end + } } impl Drop for CodeMemoryEntry { @@ -236,4 +242,12 @@ impl CodeMemory { Ok(()) } + + /// Returns whether any published segment of this code memory contains + /// `addr`. + pub fn published_contains(&self, addr: usize) -> bool { + self.entries[..self.published] + .iter() + .any(|entry| entry.contains(addr)) + } } diff --git a/crates/jit/src/compiler.rs b/crates/jit/src/compiler.rs index d0307fc32c..0bcef5f0fc 100644 --- a/crates/jit/src/compiler.rs +++ b/crates/jit/src/compiler.rs @@ -235,6 +235,12 @@ impl Compiler { pub fn signatures(&self) -> &SignatureRegistry { &self.signatures } + + /// Returns whether or not the given address falls within the JIT code + /// managed by the compiler + pub fn is_in_jit_code(&self, addr: usize) -> bool { + self.code_memory.published_contains(addr) + } } /// Create a trampoline for invoking a function. diff --git a/crates/runtime/src/instance.rs b/crates/runtime/src/instance.rs index bbf4ebfaba..1d5b17f95e 100644 --- a/crates/runtime/src/instance.rs +++ b/crates/runtime/src/instance.rs @@ -8,7 +8,7 @@ use crate::jit_int::GdbJitImageRegistration; use crate::memory::{DefaultMemoryCreator, RuntimeLinearMemory, RuntimeMemoryCreator}; use crate::table::Table; use crate::traphandlers; -use crate::traphandlers::{catch_traps, Trap}; +use crate::traphandlers::Trap; use crate::vmcontext::{ VMBuiltinFunctionsArray, VMCallerCheckedAnyfunc, VMContext, VMFunctionBody, VMFunctionImport, VMGlobalDefinition, VMGlobalImport, VMInterrupts, VMMemoryDefinition, VMMemoryImport, @@ -19,7 +19,7 @@ use memoffset::offset_of; use more_asserts::assert_lt; use std::alloc::{self, Layout}; use std::any::Any; -use std::cell::{Cell, RefCell}; +use std::cell::RefCell; use std::collections::HashMap; use std::convert::TryFrom; use std::rc::Rc; @@ -33,34 +33,6 @@ use wasmtime_environ::wasm::{ }; use wasmtime_environ::{ir, DataInitializer, EntityIndex, Module, TableElements, VMOffsets}; -cfg_if::cfg_if! { - if #[cfg(unix)] { - pub type SignalHandler = dyn Fn(libc::c_int, *const libc::siginfo_t, *const libc::c_void) -> bool; - - impl InstanceHandle { - /// Set a custom signal handler - pub fn set_signal_handler(&self, handler: H) - where - H: 'static + Fn(libc::c_int, *const libc::siginfo_t, *const libc::c_void) -> bool, - { - self.instance().signal_handler.set(Some(Box::new(handler))); - } - } - } else if #[cfg(target_os = "windows")] { - pub type SignalHandler = dyn Fn(winapi::um::winnt::PEXCEPTION_POINTERS) -> bool; - - impl InstanceHandle { - /// Set a custom signal handler - pub fn set_signal_handler(&self, handler: H) - where - H: 'static + Fn(winapi::um::winnt::PEXCEPTION_POINTERS) -> bool, - { - self.instance().signal_handler.set(Some(Box::new(handler))); - } - } - } -} - /// A WebAssembly instance. /// /// This is repr(C) to ensure that the vmctx field is last. @@ -99,9 +71,6 @@ pub(crate) struct Instance { /// Optional image of JIT'ed code for debugger registration. dbg_jit_registration: Option>, - /// Handler run when `SIGBUS`, `SIGFPE`, `SIGILL`, or `SIGSEGV` are caught by the instance thread. - pub(crate) signal_handler: Cell>>, - /// Externally allocated data indicating how this instance will be /// interrupted. pub(crate) interrupts: Arc, @@ -377,58 +346,6 @@ impl Instance { &*self.host_state } - /// Invoke the WebAssembly start function of the instance, if one is present. - fn invoke_start_function(&self, max_wasm_stack: usize) -> Result<(), InstantiationError> { - let start_index = match self.module.start_func { - Some(idx) => idx, - None => return Ok(()), - }; - - self.invoke_function_index(start_index, max_wasm_stack) - .map_err(InstantiationError::StartTrap) - } - - fn invoke_function_index( - &self, - callee_index: FuncIndex, - max_wasm_stack: usize, - ) -> Result<(), Trap> { - let (callee_address, callee_vmctx) = - match self.module.local.defined_func_index(callee_index) { - Some(defined_index) => { - let body = *self - .finished_functions - .get(defined_index) - .expect("function index is out of bounds"); - (body as *const _, self.vmctx_ptr()) - } - None => { - assert_lt!(callee_index.index(), self.module.local.num_imported_funcs); - let import = self.imported_function(callee_index); - (import.body, import.vmctx) - } - }; - - self.invoke_function(callee_vmctx, callee_address, max_wasm_stack) - } - - fn invoke_function( - &self, - callee_vmctx: *mut VMContext, - callee_address: *const VMFunctionBody, - max_wasm_stack: usize, - ) -> Result<(), Trap> { - // Make the call. - unsafe { - catch_traps(callee_vmctx, max_wasm_stack, || { - mem::transmute::< - *const VMFunctionBody, - unsafe extern "C" fn(*mut VMContext, *mut VMContext), - >(callee_address)(callee_vmctx, self.vmctx_ptr()) - }) - } - } - /// Return the offset from the vmctx pointer to its containing Instance. #[inline] pub(crate) fn vmctx_offset() -> isize { @@ -908,7 +825,6 @@ impl InstanceHandle { trampolines, dbg_jit_registration, host_state, - signal_handler: Cell::new(None), interrupts, vmctx: VMContext {}, }; @@ -988,7 +904,6 @@ impl InstanceHandle { pub unsafe fn initialize( &self, is_bulk_memory: bool, - max_wasm_stack: usize, data_initializers: &[DataInitializer<'_>], ) -> Result<(), InstantiationError> { // Check initializer bounds before initializing anything. Only do this @@ -1005,8 +920,6 @@ impl InstanceHandle { initialize_tables(self.instance())?; initialize_memories(self.instance(), data_initializers)?; - // And finally, invoke the start function. - self.instance().invoke_start_function(max_wasm_stack)?; Ok(()) } diff --git a/crates/runtime/src/lib.rs b/crates/runtime/src/lib.rs index 94816f6f13..3951355963 100644 --- a/crates/runtime/src/lib.rs +++ b/crates/runtime/src/lib.rs @@ -43,8 +43,9 @@ pub use crate::memory::{RuntimeLinearMemory, RuntimeMemoryCreator}; pub use crate::mmap::Mmap; pub use crate::sig_registry::SignatureRegistry; pub use crate::table::Table; -pub use crate::traphandlers::resume_panic; -pub use crate::traphandlers::{catch_traps, raise_lib_trap, raise_user_trap, Trap}; +pub use crate::traphandlers::{ + catch_traps, raise_lib_trap, raise_user_trap, resume_panic, SignalHandler, Trap, +}; pub use crate::vmcontext::{ VMCallerCheckedAnyfunc, VMContext, VMFunctionBody, VMFunctionImport, VMGlobalDefinition, VMGlobalImport, VMInterrupts, VMInvokeArgument, VMMemoryDefinition, VMMemoryImport, diff --git a/crates/runtime/src/traphandlers.rs b/crates/runtime/src/traphandlers.rs index 488dc3a059..d21568b90c 100644 --- a/crates/runtime/src/traphandlers.rs +++ b/crates/runtime/src/traphandlers.rs @@ -1,7 +1,6 @@ //! WebAssembly trap handling, which is built on top of the lower-level //! signalhandling mechanisms. -use crate::instance::{InstanceHandle, SignalHandler}; use crate::VMContext; use backtrace::Backtrace; use std::any::Any; @@ -26,6 +25,9 @@ cfg_if::cfg_if! { if #[cfg(unix)] { use std::mem::{self, MaybeUninit}; + /// Function which may handle custom signals while processing traps. + pub type SignalHandler<'a> = dyn Fn(libc::c_int, *const libc::siginfo_t, *const libc::c_void) -> bool + 'a; + static mut PREV_SIGSEGV: MaybeUninit = MaybeUninit::uninit(); static mut PREV_SIGBUS: MaybeUninit = MaybeUninit::uninit(); static mut PREV_SIGILL: MaybeUninit = MaybeUninit::uninit(); @@ -180,6 +182,9 @@ cfg_if::cfg_if! { use winapi::um::minwinbase::*; use winapi::vc::excpt::*; + /// Function which may handle custom signals while processing traps. + pub type SignalHandler<'a> = dyn Fn(winapi::um::winnt::PEXCEPTION_POINTERS) -> bool + 'a; + unsafe fn platform_init() { // our trap handler needs to go first, so that we can recover from // wasm faults and continue execution, so pass `1` as a true value @@ -361,6 +366,8 @@ impl Trap { pub unsafe fn catch_traps( vmctx: *mut VMContext, max_wasm_stack: usize, + is_wasm_code: impl Fn(usize) -> bool, + signal_handler: Option<&SignalHandler>, mut closure: F, ) -> Result<(), Trap> where @@ -370,7 +377,7 @@ where #[cfg(unix)] setup_unix_sigaltstack()?; - return CallThreadState::new(vmctx).with(max_wasm_stack, |cx| { + return CallThreadState::new(vmctx, &is_wasm_code, signal_handler).with(max_wasm_stack, |cx| { RegisterSetjmp( cx.jmp_buf.as_ptr(), call_closure::, @@ -388,12 +395,13 @@ where /// Temporary state stored on the stack which is registered in the `tls` module /// below for calls into wasm. -pub struct CallThreadState { +pub struct CallThreadState<'a> { unwind: Cell, jmp_buf: Cell<*const u8>, - prev: Option<*const CallThreadState>, vmctx: *mut VMContext, handling_trap: Cell, + is_wasm_code: &'a (dyn Fn(usize) -> bool + 'a), + signal_handler: Option<&'a SignalHandler<'a>>, } enum UnwindReason { @@ -404,54 +412,56 @@ enum UnwindReason { JitTrap { backtrace: Backtrace, pc: usize }, } -impl CallThreadState { - fn new(vmctx: *mut VMContext) -> CallThreadState { +impl<'a> CallThreadState<'a> { + fn new( + vmctx: *mut VMContext, + is_wasm_code: &'a (dyn Fn(usize) -> bool + 'a), + signal_handler: Option<&'a SignalHandler<'a>>, + ) -> CallThreadState<'a> { CallThreadState { unwind: Cell::new(UnwindReason::None), vmctx, jmp_buf: Cell::new(ptr::null()), - prev: None, handling_trap: Cell::new(false), + is_wasm_code, + signal_handler, } } fn with( - mut self, + self, max_wasm_stack: usize, closure: impl FnOnce(&CallThreadState) -> i32, ) -> Result<(), Trap> { - tls::with(|prev| { - self.prev = prev.map(|p| p as *const _); - let _reset = self.update_stack_limit(max_wasm_stack)?; - let ret = tls::set(&self, || closure(&self)); - match self.unwind.replace(UnwindReason::None) { - UnwindReason::None => { - debug_assert_eq!(ret, 1); - Ok(()) - } - UnwindReason::UserTrap(data) => { - debug_assert_eq!(ret, 0); - Err(Trap::User(data)) - } - UnwindReason::LibTrap(trap) => Err(trap), - UnwindReason::JitTrap { backtrace, pc } => { - debug_assert_eq!(ret, 0); - let maybe_interrupted = unsafe { - (*self.vmctx).instance().interrupts.stack_limit.load(SeqCst) - == wasmtime_environ::INTERRUPTED - }; - Err(Trap::Jit { - pc, - backtrace, - maybe_interrupted, - }) - } - UnwindReason::Panic(panic) => { - debug_assert_eq!(ret, 0); - std::panic::resume_unwind(panic) - } + let _reset = self.update_stack_limit(max_wasm_stack)?; + let ret = tls::set(&self, || closure(&self)); + match self.unwind.replace(UnwindReason::None) { + UnwindReason::None => { + debug_assert_eq!(ret, 1); + Ok(()) } - }) + UnwindReason::UserTrap(data) => { + debug_assert_eq!(ret, 0); + Err(Trap::User(data)) + } + UnwindReason::LibTrap(trap) => Err(trap), + UnwindReason::JitTrap { backtrace, pc } => { + debug_assert_eq!(ret, 0); + let maybe_interrupted = unsafe { + (*self.vmctx).instance().interrupts.stack_limit.load(SeqCst) + == wasmtime_environ::INTERRUPTED + }; + Err(Trap::Jit { + pc, + backtrace, + maybe_interrupted, + }) + } + UnwindReason::Panic(panic) => { + debug_assert_eq!(ret, 0); + std::panic::resume_unwind(panic) + } + } } /// Checks and/or initializes the wasm native call stack limit. @@ -535,18 +545,6 @@ impl CallThreadState { Ok(Reset(reset_stack_limit, &interrupts.stack_limit)) } - fn any_instance(&self, func: impl Fn(&InstanceHandle) -> bool) -> bool { - unsafe { - if func(&InstanceHandle::from_vmctx(self.vmctx)) { - return true; - } - match self.prev { - Some(prev) => (*prev).any_instance(func), - None => false, - } - } - } - fn unwind_with(&self, reason: UnwindReason) -> ! { self.unwind.replace(reason); unsafe { @@ -582,21 +580,25 @@ impl CallThreadState { if self.handling_trap.replace(true) { return ptr::null(); } + let _reset = ResetCell(&self.handling_trap, false); + + // If we haven't even started to handle traps yet, bail out. + if self.jmp_buf.get().is_null() { + return ptr::null(); + } + + // If this fault wasn't in wasm code, then it's not our problem + if !(self.is_wasm_code)(pc as usize) { + return ptr::null(); + } // First up see if any instance registered has a custom trap handler, // in which case run them all. If anything handles the trap then we // return that the trap was handled. - if self.any_instance(|i| { - let handler = match i.instance().signal_handler.replace(None) { - Some(handler) => handler, - None => return false, - }; - let result = call_handler(&handler); - i.instance().signal_handler.set(Some(handler)); - return result; - }) { - self.handling_trap.set(false); - return 1 as *const _; + if let Some(handler) = self.signal_handler { + if call_handler(handler) { + return 1 as *const _; + } } // TODO: stack overflow can happen at any random time (i.e. in malloc() @@ -607,7 +609,6 @@ impl CallThreadState { // doesn't trap. Then, if we have called some WebAssembly code, it // means the trap is stack overflow. if self.jmp_buf.get().is_null() { - self.handling_trap.set(false); return ptr::null(); } let backtrace = Backtrace::new_unresolved(); @@ -615,11 +616,18 @@ impl CallThreadState { backtrace, pc: pc as usize, }); - self.handling_trap.set(false); self.jmp_buf.get() } } +struct ResetCell<'a, T: Copy>(&'a Cell, T); + +impl Drop for ResetCell<'_, T> { + fn drop(&mut self) { + self.0.set(self.1); + } +} + // A private inner module for managing the TLS state that we require across // calls in wasm. The WebAssembly code is called from C++ and then a trap may // happen which requires us to read some contextual state to figure out what to @@ -628,14 +636,15 @@ impl CallThreadState { mod tls { use super::CallThreadState; use std::cell::Cell; + use std::mem; use std::ptr; - thread_local!(static PTR: Cell<*const CallThreadState> = Cell::new(ptr::null())); + thread_local!(static PTR: Cell<*const CallThreadState<'static>> = Cell::new(ptr::null())); /// Configures thread local state such that for the duration of the /// execution of `closure` any call to `with` will yield `ptr`, unless this /// is recursively called again. - pub fn set(ptr: &CallThreadState, closure: impl FnOnce() -> R) -> R { + pub fn set(ptr: &CallThreadState<'_>, closure: impl FnOnce() -> R) -> R { struct Reset<'a, T: Copy>(&'a Cell, T); impl Drop for Reset<'_, T> { @@ -645,6 +654,12 @@ mod tls { } PTR.with(|p| { + // Note that this extension of the lifetime to `'static` should be + // safe because we only ever access it below with an anonymous + // lifetime, meaning `'static` never leaks out of this module. + let ptr = unsafe { + mem::transmute::<*const CallThreadState<'_>, *const CallThreadState<'static>>(ptr) + }; let _r = Reset(p, p.replace(ptr)); closure() }) @@ -652,7 +667,7 @@ mod tls { /// Returns the last pointer configured with `set` above. Panics if `set` /// has not been previously called. - pub fn with(closure: impl FnOnce(Option<&CallThreadState>) -> R) -> R { + pub fn with(closure: impl FnOnce(Option<&CallThreadState<'_>>) -> R) -> R { PTR.with(|ptr| { let p = ptr.get(); unsafe { closure(if p.is_null() { None } else { Some(&*p) }) } diff --git a/tests/all/custom_signal_handler.rs b/tests/all/custom_signal_handler.rs index c75821ef17..ba5a96b162 100644 --- a/tests/all/custom_signal_handler.rs +++ b/tests/all/custom_signal_handler.rs @@ -3,7 +3,7 @@ mod tests { use anyhow::Result; use std::rc::Rc; use std::sync::atomic::{AtomicBool, Ordering}; - use wasmtime::unix::InstanceExt; + use wasmtime::unix::StoreExt; use wasmtime::*; const WAT1: &str = r#" @@ -96,7 +96,7 @@ mod tests { let (base, length) = set_up_memory(&instance); unsafe { - instance.set_signal_handler(move |signum, siginfo, _| { + store.set_signal_handler(move |signum, siginfo, _| { handle_sigsegv(base, length, signum, siginfo) }); } @@ -161,7 +161,7 @@ mod tests { unsafe { let (base1, length1) = set_up_memory(&instance1); - instance1.set_signal_handler({ + store.set_signal_handler({ let instance1_handler_triggered = instance1_handler_triggered.clone(); move |_signum, _siginfo, _context| { // Remove protections so the execution may resume @@ -180,31 +180,6 @@ mod tests { }); } - let instance2 = Instance::new(&module, &[]).expect("failed to instantiate module"); - let instance2_handler_triggered = Rc::new(AtomicBool::new(false)); - - unsafe { - let (base2, length2) = set_up_memory(&instance2); - - instance2.set_signal_handler({ - let instance2_handler_triggered = instance2_handler_triggered.clone(); - move |_signum, _siginfo, _context| { - // Remove protections so the execution may resume - libc::mprotect( - base2 as *mut libc::c_void, - length2, - libc::PROT_READ | libc::PROT_WRITE, - ); - instance2_handler_triggered.store(true, Ordering::SeqCst); - println!( - "Hello from instance2 signal handler! {}", - instance2_handler_triggered.load(Ordering::SeqCst) - ); - true - } - }); - } - // Invoke both instances and trigger both signal handlers // First instance1 @@ -222,6 +197,31 @@ mod tests { ); } + let instance2 = Instance::new(&module, &[]).expect("failed to instantiate module"); + let instance2_handler_triggered = Rc::new(AtomicBool::new(false)); + + unsafe { + let (base2, length2) = set_up_memory(&instance2); + + store.set_signal_handler({ + let instance2_handler_triggered = instance2_handler_triggered.clone(); + move |_signum, _siginfo, _context| { + // Remove protections so the execution may resume + libc::mprotect( + base2 as *mut libc::c_void, + length2, + libc::PROT_READ | libc::PROT_WRITE, + ); + instance2_handler_triggered.store(true, Ordering::SeqCst); + println!( + "Hello from instance2 signal handler! {}", + instance2_handler_triggered.load(Ordering::SeqCst) + ); + true + } + }); + } + // And then instance2 { let mut exports2 = instance2.exports(); @@ -249,7 +249,7 @@ mod tests { let instance1 = Instance::new(&module1, &[])?; let (base1, length1) = set_up_memory(&instance1); unsafe { - instance1.set_signal_handler(move |signum, siginfo, _| { + store.set_signal_handler(move |signum, siginfo, _| { println!("instance1"); handle_sigsegv(base1, length1, signum, siginfo) }); @@ -264,7 +264,7 @@ mod tests { // since 'instance2.run' calls 'instance1.read' we need to set up the signal handler to handle // SIGSEGV originating from within the memory of instance1 unsafe { - instance2.set_signal_handler(move |signum, siginfo, _| { + store.set_signal_handler(move |signum, siginfo, _| { handle_sigsegv(base1, length1, signum, siginfo) }); } diff --git a/tests/host_segfault.rs b/tests/host_segfault.rs index cc29fe87db..1884967287 100644 --- a/tests/host_segfault.rs +++ b/tests/host_segfault.rs @@ -59,6 +59,12 @@ fn main() { let _instance = Instance::new(&module, &[]).unwrap(); println!("stack overrun: {}", overrun_the_stack()); }), + ("segfault in a host function", || { + let store = Store::default(); + let module = Module::new(&store, r#"(import "" "" (func)) (start 0)"#).unwrap(); + let segfault = Func::wrap(&store, || segfault()); + Instance::new(&module, &[segfault.into()]).unwrap(); + }), ]; match env::var(VAR_NAME) { Ok(s) => {