diff --git a/crates/api/src/func.rs b/crates/api/src/func.rs index ac1838a5e4..b1b18be7c1 100644 --- a/crates/api/src/func.rs +++ b/crates/api/src/func.rs @@ -178,7 +178,6 @@ macro_rules! getters { // of the closure. Pass the export in so that we can call it. let instance = self.instance.clone(); let export = self.export.clone(); - let max_wasm_stack = self.store().engine().config().max_wasm_stack; // ... and then once we've passed the typechecks we can hand out our // object since our `transmute` below should be safe! @@ -194,13 +193,9 @@ macro_rules! getters { >(export.address); let mut ret = None; $(let $args = $args.into_abi();)* - wasmtime_runtime::catch_traps(export.vmctx, max_wasm_stack, || { + catch_traps(export.vmctx, &instance.store, || { ret = Some(fnptr(export.vmctx, ptr::null_mut(), $($args,)*)); - }).map_err(Trap::from_jit)?; - - // We're holding this handle just to ensure that the instance stays - // live while we call into it. - drop(&instance); + })?; Ok(ret.unwrap()) } @@ -562,22 +557,14 @@ impl Func { } // Call the trampoline. - if let Err(error) = unsafe { - wasmtime_runtime::catch_traps( + catch_traps(self.export.vmctx, &self.instance.store, || unsafe { + (self.trampoline)( self.export.vmctx, - self.instance.store.engine().config().max_wasm_stack, - || { - (self.trampoline)( - self.export.vmctx, - ptr::null_mut(), - self.export.address, - values_vec.as_mut_ptr(), - ) - }, + ptr::null_mut(), + self.export.address, + values_vec.as_mut_ptr(), ) - } { - return Err(Trap::from_jit(error).into()); - } + })?; // Load the return values out of `values_vec`. let mut results = Vec::with_capacity(my_ty.results().len()); @@ -746,6 +733,24 @@ impl fmt::Debug for Func { } } +pub(crate) fn catch_traps( + vmctx: *mut VMContext, + store: &Store, + closure: impl FnMut(), +) -> Result<(), Trap> { + let signalhandler = store.signal_handler(); + unsafe { + wasmtime_runtime::catch_traps( + vmctx, + store.engine().config().max_wasm_stack, + |addr| store.compiler().is_in_jit_code(addr), + signalhandler.as_deref(), + closure, + ) + .map_err(Trap::from_jit) + } +} + /// A trait implemented for types which can be arguments to closures passed to /// [`Func::wrap`] and friends. /// diff --git a/crates/api/src/instance.rs b/crates/api/src/instance.rs index 02c535e773..808f5d3979 100644 --- a/crates/api/src/instance.rs +++ b/crates/api/src/instance.rs @@ -2,8 +2,10 @@ use crate::trampoline::StoreInstanceHandle; use crate::{Export, Extern, Func, Global, Memory, Module, Store, Table, Trap}; use anyhow::{bail, Error, Result}; use std::any::Any; +use std::mem; +use wasmtime_environ::EntityIndex; use wasmtime_jit::{CompiledModule, Resolver}; -use wasmtime_runtime::{InstantiationError, SignatureRegistry}; +use wasmtime_runtime::{InstantiationError, SignatureRegistry, VMContext, VMFunctionBody}; struct SimpleResolver<'a> { imports: &'a [Extern], @@ -45,7 +47,6 @@ fn instantiate( instance .initialize( config.validating_config.operator_config.enable_bulk_memory, - config.max_wasm_stack, &compiled_module.data_initializers(), ) .map_err(|e| -> Error { @@ -56,6 +57,23 @@ fn instantiate( other => other.into(), } })?; + + // If a start function is present, now that we've got our compiled + // instance we can invoke it. Make sure we use all the trap-handling + // configuration in `store` as well. + if let Some(start) = instance.module().start_func { + let f = match instance.lookup_by_declaration(&EntityIndex::Function(start)) { + wasmtime_runtime::Export::Function(f) => f, + _ => unreachable!(), // valid modules shouldn't hit this + }; + super::func::catch_traps(instance.vmctx_ptr(), store, || { + mem::transmute::< + *const VMFunctionBody, + unsafe extern "C" fn(*mut VMContext, *mut VMContext), + >(f.address)(f.vmctx, instance.vmctx_ptr()) + })?; + } + Ok(instance) } } diff --git a/crates/api/src/runtime.rs b/crates/api/src/runtime.rs index 6592cc93df..d81b11a064 100644 --- a/crates/api/src/runtime.rs +++ b/crates/api/src/runtime.rs @@ -12,7 +12,9 @@ use wasmtime_environ::settings::{self, Configurable}; use wasmtime_environ::{CacheConfig, Tunables}; use wasmtime_jit::{native, CompilationStrategy, Compiler}; use wasmtime_profiling::{JitDumpAgent, NullProfilerAgent, ProfilingAgent, VTuneAgent}; -use wasmtime_runtime::{debug_builtins, InstanceHandle, RuntimeMemoryCreator, VMInterrupts}; +use wasmtime_runtime::{ + debug_builtins, InstanceHandle, RuntimeMemoryCreator, SignalHandler, VMInterrupts, +}; // Runtime Environment @@ -557,6 +559,7 @@ pub(crate) struct StoreInner { engine: Engine, compiler: RefCell, instances: RefCell>, + signal_handler: RefCell>>>, } impl Store { @@ -574,6 +577,7 @@ impl Store { engine: engine.clone(), compiler: RefCell::new(compiler), instances: RefCell::new(Vec::new()), + signal_handler: RefCell::new(None), }), } } @@ -625,6 +629,16 @@ impl Store { Rc::downgrade(&self.inner) } + pub(crate) fn signal_handler(&self) -> std::cell::Ref<'_, Option>>> { + self.inner.signal_handler.borrow() + } + + pub(crate) fn signal_handler_mut( + &self, + ) -> std::cell::RefMut<'_, Option>>> { + self.inner.signal_handler.borrow_mut() + } + /// Returns whether the stores `a` and `b` refer to the same underlying /// `Store`. /// diff --git a/crates/api/src/unix.rs b/crates/api/src/unix.rs index 22fa392597..4e2f8c61c5 100644 --- a/crates/api/src/unix.rs +++ b/crates/api/src/unix.rs @@ -9,10 +9,10 @@ //! throughout the `wasmtime` crate with extra functionality that's only //! available on Unix. -use crate::Instance; +use crate::Store; -/// Extensions for the [`Instance`] type only available on Unix. -pub trait InstanceExt { +/// Extensions for the [`Store`] type only available on Unix. +pub trait StoreExt { // TODO: needs more docs? /// The signal handler must be /// [async-signal-safe](http://man7.org/linux/man-pages/man7/signal-safety.7.html). @@ -21,11 +21,11 @@ pub trait InstanceExt { H: 'static + Fn(libc::c_int, *const libc::siginfo_t, *const libc::c_void) -> bool; } -impl InstanceExt for Instance { +impl StoreExt for Store { unsafe fn set_signal_handler(&self, handler: H) where H: 'static + Fn(libc::c_int, *const libc::siginfo_t, *const libc::c_void) -> bool, { - self.handle.set_signal_handler(handler); + *self.signal_handler_mut() = Some(Box::new(handler)); } } diff --git a/crates/api/src/windows.rs b/crates/api/src/windows.rs index 5725007d53..f691e1aa35 100644 --- a/crates/api/src/windows.rs +++ b/crates/api/src/windows.rs @@ -9,10 +9,10 @@ //! throughout the `wasmtime` crate with extra functionality that's only //! available on Windows. -use crate::Instance; +use crate::Store; -/// Extensions for the [`Instance`] type only available on Windows. -pub trait InstanceExt { +/// Extensions for the [`Store`] type only available on Windows. +pub trait StoreExt { /// Configures a custom signal handler to execute. /// /// TODO: needs more documentation. @@ -21,11 +21,11 @@ pub trait InstanceExt { H: 'static + Fn(winapi::um::winnt::PEXCEPTION_POINTERS) -> bool; } -impl InstanceExt for Instance { +impl StoreExt for Store { unsafe fn set_signal_handler(&self, handler: H) where H: 'static + Fn(winapi::um::winnt::PEXCEPTION_POINTERS) -> bool, { - self.handle.set_signal_handler(handler); + *self.signal_handler_mut() = Some(Box::new(handler)); } } diff --git a/crates/jit/src/code_memory.rs b/crates/jit/src/code_memory.rs index 6d90c85710..e0b7315c4d 100644 --- a/crates/jit/src/code_memory.rs +++ b/crates/jit/src/code_memory.rs @@ -21,6 +21,12 @@ impl CodeMemoryEntry { let registry = ManuallyDrop::new(UnwindRegistry::new(mmap.as_ptr() as usize)); Ok(Self { mmap, registry }) } + + fn contains(&self, addr: usize) -> bool { + let start = self.mmap.as_ptr() as usize; + let end = start + self.mmap.len(); + start <= addr && addr < end + } } impl Drop for CodeMemoryEntry { @@ -236,4 +242,12 @@ impl CodeMemory { Ok(()) } + + /// Returns whether any published segment of this code memory contains + /// `addr`. + pub fn published_contains(&self, addr: usize) -> bool { + self.entries[..self.published] + .iter() + .any(|entry| entry.contains(addr)) + } } diff --git a/crates/jit/src/compiler.rs b/crates/jit/src/compiler.rs index d0307fc32c..0bcef5f0fc 100644 --- a/crates/jit/src/compiler.rs +++ b/crates/jit/src/compiler.rs @@ -235,6 +235,12 @@ impl Compiler { pub fn signatures(&self) -> &SignatureRegistry { &self.signatures } + + /// Returns whether or not the given address falls within the JIT code + /// managed by the compiler + pub fn is_in_jit_code(&self, addr: usize) -> bool { + self.code_memory.published_contains(addr) + } } /// Create a trampoline for invoking a function. diff --git a/crates/runtime/src/instance.rs b/crates/runtime/src/instance.rs index bbf4ebfaba..1d5b17f95e 100644 --- a/crates/runtime/src/instance.rs +++ b/crates/runtime/src/instance.rs @@ -8,7 +8,7 @@ use crate::jit_int::GdbJitImageRegistration; use crate::memory::{DefaultMemoryCreator, RuntimeLinearMemory, RuntimeMemoryCreator}; use crate::table::Table; use crate::traphandlers; -use crate::traphandlers::{catch_traps, Trap}; +use crate::traphandlers::Trap; use crate::vmcontext::{ VMBuiltinFunctionsArray, VMCallerCheckedAnyfunc, VMContext, VMFunctionBody, VMFunctionImport, VMGlobalDefinition, VMGlobalImport, VMInterrupts, VMMemoryDefinition, VMMemoryImport, @@ -19,7 +19,7 @@ use memoffset::offset_of; use more_asserts::assert_lt; use std::alloc::{self, Layout}; use std::any::Any; -use std::cell::{Cell, RefCell}; +use std::cell::RefCell; use std::collections::HashMap; use std::convert::TryFrom; use std::rc::Rc; @@ -33,34 +33,6 @@ use wasmtime_environ::wasm::{ }; use wasmtime_environ::{ir, DataInitializer, EntityIndex, Module, TableElements, VMOffsets}; -cfg_if::cfg_if! { - if #[cfg(unix)] { - pub type SignalHandler = dyn Fn(libc::c_int, *const libc::siginfo_t, *const libc::c_void) -> bool; - - impl InstanceHandle { - /// Set a custom signal handler - pub fn set_signal_handler(&self, handler: H) - where - H: 'static + Fn(libc::c_int, *const libc::siginfo_t, *const libc::c_void) -> bool, - { - self.instance().signal_handler.set(Some(Box::new(handler))); - } - } - } else if #[cfg(target_os = "windows")] { - pub type SignalHandler = dyn Fn(winapi::um::winnt::PEXCEPTION_POINTERS) -> bool; - - impl InstanceHandle { - /// Set a custom signal handler - pub fn set_signal_handler(&self, handler: H) - where - H: 'static + Fn(winapi::um::winnt::PEXCEPTION_POINTERS) -> bool, - { - self.instance().signal_handler.set(Some(Box::new(handler))); - } - } - } -} - /// A WebAssembly instance. /// /// This is repr(C) to ensure that the vmctx field is last. @@ -99,9 +71,6 @@ pub(crate) struct Instance { /// Optional image of JIT'ed code for debugger registration. dbg_jit_registration: Option>, - /// Handler run when `SIGBUS`, `SIGFPE`, `SIGILL`, or `SIGSEGV` are caught by the instance thread. - pub(crate) signal_handler: Cell>>, - /// Externally allocated data indicating how this instance will be /// interrupted. pub(crate) interrupts: Arc, @@ -377,58 +346,6 @@ impl Instance { &*self.host_state } - /// Invoke the WebAssembly start function of the instance, if one is present. - fn invoke_start_function(&self, max_wasm_stack: usize) -> Result<(), InstantiationError> { - let start_index = match self.module.start_func { - Some(idx) => idx, - None => return Ok(()), - }; - - self.invoke_function_index(start_index, max_wasm_stack) - .map_err(InstantiationError::StartTrap) - } - - fn invoke_function_index( - &self, - callee_index: FuncIndex, - max_wasm_stack: usize, - ) -> Result<(), Trap> { - let (callee_address, callee_vmctx) = - match self.module.local.defined_func_index(callee_index) { - Some(defined_index) => { - let body = *self - .finished_functions - .get(defined_index) - .expect("function index is out of bounds"); - (body as *const _, self.vmctx_ptr()) - } - None => { - assert_lt!(callee_index.index(), self.module.local.num_imported_funcs); - let import = self.imported_function(callee_index); - (import.body, import.vmctx) - } - }; - - self.invoke_function(callee_vmctx, callee_address, max_wasm_stack) - } - - fn invoke_function( - &self, - callee_vmctx: *mut VMContext, - callee_address: *const VMFunctionBody, - max_wasm_stack: usize, - ) -> Result<(), Trap> { - // Make the call. - unsafe { - catch_traps(callee_vmctx, max_wasm_stack, || { - mem::transmute::< - *const VMFunctionBody, - unsafe extern "C" fn(*mut VMContext, *mut VMContext), - >(callee_address)(callee_vmctx, self.vmctx_ptr()) - }) - } - } - /// Return the offset from the vmctx pointer to its containing Instance. #[inline] pub(crate) fn vmctx_offset() -> isize { @@ -908,7 +825,6 @@ impl InstanceHandle { trampolines, dbg_jit_registration, host_state, - signal_handler: Cell::new(None), interrupts, vmctx: VMContext {}, }; @@ -988,7 +904,6 @@ impl InstanceHandle { pub unsafe fn initialize( &self, is_bulk_memory: bool, - max_wasm_stack: usize, data_initializers: &[DataInitializer<'_>], ) -> Result<(), InstantiationError> { // Check initializer bounds before initializing anything. Only do this @@ -1005,8 +920,6 @@ impl InstanceHandle { initialize_tables(self.instance())?; initialize_memories(self.instance(), data_initializers)?; - // And finally, invoke the start function. - self.instance().invoke_start_function(max_wasm_stack)?; Ok(()) } diff --git a/crates/runtime/src/lib.rs b/crates/runtime/src/lib.rs index 94816f6f13..3951355963 100644 --- a/crates/runtime/src/lib.rs +++ b/crates/runtime/src/lib.rs @@ -43,8 +43,9 @@ pub use crate::memory::{RuntimeLinearMemory, RuntimeMemoryCreator}; pub use crate::mmap::Mmap; pub use crate::sig_registry::SignatureRegistry; pub use crate::table::Table; -pub use crate::traphandlers::resume_panic; -pub use crate::traphandlers::{catch_traps, raise_lib_trap, raise_user_trap, Trap}; +pub use crate::traphandlers::{ + catch_traps, raise_lib_trap, raise_user_trap, resume_panic, SignalHandler, Trap, +}; pub use crate::vmcontext::{ VMCallerCheckedAnyfunc, VMContext, VMFunctionBody, VMFunctionImport, VMGlobalDefinition, VMGlobalImport, VMInterrupts, VMInvokeArgument, VMMemoryDefinition, VMMemoryImport, diff --git a/crates/runtime/src/traphandlers.rs b/crates/runtime/src/traphandlers.rs index 488dc3a059..d21568b90c 100644 --- a/crates/runtime/src/traphandlers.rs +++ b/crates/runtime/src/traphandlers.rs @@ -1,7 +1,6 @@ //! WebAssembly trap handling, which is built on top of the lower-level //! signalhandling mechanisms. -use crate::instance::{InstanceHandle, SignalHandler}; use crate::VMContext; use backtrace::Backtrace; use std::any::Any; @@ -26,6 +25,9 @@ cfg_if::cfg_if! { if #[cfg(unix)] { use std::mem::{self, MaybeUninit}; + /// Function which may handle custom signals while processing traps. + pub type SignalHandler<'a> = dyn Fn(libc::c_int, *const libc::siginfo_t, *const libc::c_void) -> bool + 'a; + static mut PREV_SIGSEGV: MaybeUninit = MaybeUninit::uninit(); static mut PREV_SIGBUS: MaybeUninit = MaybeUninit::uninit(); static mut PREV_SIGILL: MaybeUninit = MaybeUninit::uninit(); @@ -180,6 +182,9 @@ cfg_if::cfg_if! { use winapi::um::minwinbase::*; use winapi::vc::excpt::*; + /// Function which may handle custom signals while processing traps. + pub type SignalHandler<'a> = dyn Fn(winapi::um::winnt::PEXCEPTION_POINTERS) -> bool + 'a; + unsafe fn platform_init() { // our trap handler needs to go first, so that we can recover from // wasm faults and continue execution, so pass `1` as a true value @@ -361,6 +366,8 @@ impl Trap { pub unsafe fn catch_traps( vmctx: *mut VMContext, max_wasm_stack: usize, + is_wasm_code: impl Fn(usize) -> bool, + signal_handler: Option<&SignalHandler>, mut closure: F, ) -> Result<(), Trap> where @@ -370,7 +377,7 @@ where #[cfg(unix)] setup_unix_sigaltstack()?; - return CallThreadState::new(vmctx).with(max_wasm_stack, |cx| { + return CallThreadState::new(vmctx, &is_wasm_code, signal_handler).with(max_wasm_stack, |cx| { RegisterSetjmp( cx.jmp_buf.as_ptr(), call_closure::, @@ -388,12 +395,13 @@ where /// Temporary state stored on the stack which is registered in the `tls` module /// below for calls into wasm. -pub struct CallThreadState { +pub struct CallThreadState<'a> { unwind: Cell, jmp_buf: Cell<*const u8>, - prev: Option<*const CallThreadState>, vmctx: *mut VMContext, handling_trap: Cell, + is_wasm_code: &'a (dyn Fn(usize) -> bool + 'a), + signal_handler: Option<&'a SignalHandler<'a>>, } enum UnwindReason { @@ -404,54 +412,56 @@ enum UnwindReason { JitTrap { backtrace: Backtrace, pc: usize }, } -impl CallThreadState { - fn new(vmctx: *mut VMContext) -> CallThreadState { +impl<'a> CallThreadState<'a> { + fn new( + vmctx: *mut VMContext, + is_wasm_code: &'a (dyn Fn(usize) -> bool + 'a), + signal_handler: Option<&'a SignalHandler<'a>>, + ) -> CallThreadState<'a> { CallThreadState { unwind: Cell::new(UnwindReason::None), vmctx, jmp_buf: Cell::new(ptr::null()), - prev: None, handling_trap: Cell::new(false), + is_wasm_code, + signal_handler, } } fn with( - mut self, + self, max_wasm_stack: usize, closure: impl FnOnce(&CallThreadState) -> i32, ) -> Result<(), Trap> { - tls::with(|prev| { - self.prev = prev.map(|p| p as *const _); - let _reset = self.update_stack_limit(max_wasm_stack)?; - let ret = tls::set(&self, || closure(&self)); - match self.unwind.replace(UnwindReason::None) { - UnwindReason::None => { - debug_assert_eq!(ret, 1); - Ok(()) - } - UnwindReason::UserTrap(data) => { - debug_assert_eq!(ret, 0); - Err(Trap::User(data)) - } - UnwindReason::LibTrap(trap) => Err(trap), - UnwindReason::JitTrap { backtrace, pc } => { - debug_assert_eq!(ret, 0); - let maybe_interrupted = unsafe { - (*self.vmctx).instance().interrupts.stack_limit.load(SeqCst) - == wasmtime_environ::INTERRUPTED - }; - Err(Trap::Jit { - pc, - backtrace, - maybe_interrupted, - }) - } - UnwindReason::Panic(panic) => { - debug_assert_eq!(ret, 0); - std::panic::resume_unwind(panic) - } + let _reset = self.update_stack_limit(max_wasm_stack)?; + let ret = tls::set(&self, || closure(&self)); + match self.unwind.replace(UnwindReason::None) { + UnwindReason::None => { + debug_assert_eq!(ret, 1); + Ok(()) } - }) + UnwindReason::UserTrap(data) => { + debug_assert_eq!(ret, 0); + Err(Trap::User(data)) + } + UnwindReason::LibTrap(trap) => Err(trap), + UnwindReason::JitTrap { backtrace, pc } => { + debug_assert_eq!(ret, 0); + let maybe_interrupted = unsafe { + (*self.vmctx).instance().interrupts.stack_limit.load(SeqCst) + == wasmtime_environ::INTERRUPTED + }; + Err(Trap::Jit { + pc, + backtrace, + maybe_interrupted, + }) + } + UnwindReason::Panic(panic) => { + debug_assert_eq!(ret, 0); + std::panic::resume_unwind(panic) + } + } } /// Checks and/or initializes the wasm native call stack limit. @@ -535,18 +545,6 @@ impl CallThreadState { Ok(Reset(reset_stack_limit, &interrupts.stack_limit)) } - fn any_instance(&self, func: impl Fn(&InstanceHandle) -> bool) -> bool { - unsafe { - if func(&InstanceHandle::from_vmctx(self.vmctx)) { - return true; - } - match self.prev { - Some(prev) => (*prev).any_instance(func), - None => false, - } - } - } - fn unwind_with(&self, reason: UnwindReason) -> ! { self.unwind.replace(reason); unsafe { @@ -582,21 +580,25 @@ impl CallThreadState { if self.handling_trap.replace(true) { return ptr::null(); } + let _reset = ResetCell(&self.handling_trap, false); + + // If we haven't even started to handle traps yet, bail out. + if self.jmp_buf.get().is_null() { + return ptr::null(); + } + + // If this fault wasn't in wasm code, then it's not our problem + if !(self.is_wasm_code)(pc as usize) { + return ptr::null(); + } // First up see if any instance registered has a custom trap handler, // in which case run them all. If anything handles the trap then we // return that the trap was handled. - if self.any_instance(|i| { - let handler = match i.instance().signal_handler.replace(None) { - Some(handler) => handler, - None => return false, - }; - let result = call_handler(&handler); - i.instance().signal_handler.set(Some(handler)); - return result; - }) { - self.handling_trap.set(false); - return 1 as *const _; + if let Some(handler) = self.signal_handler { + if call_handler(handler) { + return 1 as *const _; + } } // TODO: stack overflow can happen at any random time (i.e. in malloc() @@ -607,7 +609,6 @@ impl CallThreadState { // doesn't trap. Then, if we have called some WebAssembly code, it // means the trap is stack overflow. if self.jmp_buf.get().is_null() { - self.handling_trap.set(false); return ptr::null(); } let backtrace = Backtrace::new_unresolved(); @@ -615,11 +616,18 @@ impl CallThreadState { backtrace, pc: pc as usize, }); - self.handling_trap.set(false); self.jmp_buf.get() } } +struct ResetCell<'a, T: Copy>(&'a Cell, T); + +impl Drop for ResetCell<'_, T> { + fn drop(&mut self) { + self.0.set(self.1); + } +} + // A private inner module for managing the TLS state that we require across // calls in wasm. The WebAssembly code is called from C++ and then a trap may // happen which requires us to read some contextual state to figure out what to @@ -628,14 +636,15 @@ impl CallThreadState { mod tls { use super::CallThreadState; use std::cell::Cell; + use std::mem; use std::ptr; - thread_local!(static PTR: Cell<*const CallThreadState> = Cell::new(ptr::null())); + thread_local!(static PTR: Cell<*const CallThreadState<'static>> = Cell::new(ptr::null())); /// Configures thread local state such that for the duration of the /// execution of `closure` any call to `with` will yield `ptr`, unless this /// is recursively called again. - pub fn set(ptr: &CallThreadState, closure: impl FnOnce() -> R) -> R { + pub fn set(ptr: &CallThreadState<'_>, closure: impl FnOnce() -> R) -> R { struct Reset<'a, T: Copy>(&'a Cell, T); impl Drop for Reset<'_, T> { @@ -645,6 +654,12 @@ mod tls { } PTR.with(|p| { + // Note that this extension of the lifetime to `'static` should be + // safe because we only ever access it below with an anonymous + // lifetime, meaning `'static` never leaks out of this module. + let ptr = unsafe { + mem::transmute::<*const CallThreadState<'_>, *const CallThreadState<'static>>(ptr) + }; let _r = Reset(p, p.replace(ptr)); closure() }) @@ -652,7 +667,7 @@ mod tls { /// Returns the last pointer configured with `set` above. Panics if `set` /// has not been previously called. - pub fn with(closure: impl FnOnce(Option<&CallThreadState>) -> R) -> R { + pub fn with(closure: impl FnOnce(Option<&CallThreadState<'_>>) -> R) -> R { PTR.with(|ptr| { let p = ptr.get(); unsafe { closure(if p.is_null() { None } else { Some(&*p) }) } diff --git a/tests/all/custom_signal_handler.rs b/tests/all/custom_signal_handler.rs index c75821ef17..ba5a96b162 100644 --- a/tests/all/custom_signal_handler.rs +++ b/tests/all/custom_signal_handler.rs @@ -3,7 +3,7 @@ mod tests { use anyhow::Result; use std::rc::Rc; use std::sync::atomic::{AtomicBool, Ordering}; - use wasmtime::unix::InstanceExt; + use wasmtime::unix::StoreExt; use wasmtime::*; const WAT1: &str = r#" @@ -96,7 +96,7 @@ mod tests { let (base, length) = set_up_memory(&instance); unsafe { - instance.set_signal_handler(move |signum, siginfo, _| { + store.set_signal_handler(move |signum, siginfo, _| { handle_sigsegv(base, length, signum, siginfo) }); } @@ -161,7 +161,7 @@ mod tests { unsafe { let (base1, length1) = set_up_memory(&instance1); - instance1.set_signal_handler({ + store.set_signal_handler({ let instance1_handler_triggered = instance1_handler_triggered.clone(); move |_signum, _siginfo, _context| { // Remove protections so the execution may resume @@ -180,31 +180,6 @@ mod tests { }); } - let instance2 = Instance::new(&module, &[]).expect("failed to instantiate module"); - let instance2_handler_triggered = Rc::new(AtomicBool::new(false)); - - unsafe { - let (base2, length2) = set_up_memory(&instance2); - - instance2.set_signal_handler({ - let instance2_handler_triggered = instance2_handler_triggered.clone(); - move |_signum, _siginfo, _context| { - // Remove protections so the execution may resume - libc::mprotect( - base2 as *mut libc::c_void, - length2, - libc::PROT_READ | libc::PROT_WRITE, - ); - instance2_handler_triggered.store(true, Ordering::SeqCst); - println!( - "Hello from instance2 signal handler! {}", - instance2_handler_triggered.load(Ordering::SeqCst) - ); - true - } - }); - } - // Invoke both instances and trigger both signal handlers // First instance1 @@ -222,6 +197,31 @@ mod tests { ); } + let instance2 = Instance::new(&module, &[]).expect("failed to instantiate module"); + let instance2_handler_triggered = Rc::new(AtomicBool::new(false)); + + unsafe { + let (base2, length2) = set_up_memory(&instance2); + + store.set_signal_handler({ + let instance2_handler_triggered = instance2_handler_triggered.clone(); + move |_signum, _siginfo, _context| { + // Remove protections so the execution may resume + libc::mprotect( + base2 as *mut libc::c_void, + length2, + libc::PROT_READ | libc::PROT_WRITE, + ); + instance2_handler_triggered.store(true, Ordering::SeqCst); + println!( + "Hello from instance2 signal handler! {}", + instance2_handler_triggered.load(Ordering::SeqCst) + ); + true + } + }); + } + // And then instance2 { let mut exports2 = instance2.exports(); @@ -249,7 +249,7 @@ mod tests { let instance1 = Instance::new(&module1, &[])?; let (base1, length1) = set_up_memory(&instance1); unsafe { - instance1.set_signal_handler(move |signum, siginfo, _| { + store.set_signal_handler(move |signum, siginfo, _| { println!("instance1"); handle_sigsegv(base1, length1, signum, siginfo) }); @@ -264,7 +264,7 @@ mod tests { // since 'instance2.run' calls 'instance1.read' we need to set up the signal handler to handle // SIGSEGV originating from within the memory of instance1 unsafe { - instance2.set_signal_handler(move |signum, siginfo, _| { + store.set_signal_handler(move |signum, siginfo, _| { handle_sigsegv(base1, length1, signum, siginfo) }); } diff --git a/tests/host_segfault.rs b/tests/host_segfault.rs index cc29fe87db..1884967287 100644 --- a/tests/host_segfault.rs +++ b/tests/host_segfault.rs @@ -59,6 +59,12 @@ fn main() { let _instance = Instance::new(&module, &[]).unwrap(); println!("stack overrun: {}", overrun_the_stack()); }), + ("segfault in a host function", || { + let store = Store::default(); + let module = Module::new(&store, r#"(import "" "" (func)) (start 0)"#).unwrap(); + let segfault = Func::wrap(&store, || segfault()); + Instance::new(&module, &[segfault.into()]).unwrap(); + }), ]; match env::var(VAR_NAME) { Ok(s) => {