Bring back per-thread lazy initialization (#2863)

* Bring back per-thread lazy initialization

Platforms Wasmtime supports may have per-thread initialization that
needs to run before WebAssembly. For example, Unix needs to set up a
sigaltstack and macOS needs to set up Mach ports. In #2757 this
per-thread setup was moved out of the invocation of a wasm function,
relying on the lack of Send for Store to initialize the thread at Store
creation time and never worry about it later.

This conflicted with [wasmtime's desired multithreading
story](https://github.com/bytecodealliance/wasmtime/pull/2812) so a new
[`Store::notify_switched_thread` was
added](https://github.com/bytecodealliance/wasmtime/pull/2822) to
explicitly indicate a Store has moved to another thread (if it unsafely
did so).

It turns out though that it's not always easy to determine when a
`Store` moves to a new thread. For example the Go bindings for Wasmtime
are generally unaware when a goroutine switches OS threads. This led to
https://github.com/bytecodealliance/wasmtime-go/issues/74 where a SIGILL
was left uncaught, making it appear that traps aren't working properly.

This commit revisits the decision in #2757 and moves per-thread
initialization back into the path of calling into WebAssembly. This
differs from the earlier approach, though: there's still only one TLS
access on the path of calling into WebAssembly, whereas before the
per-thread check was a separate access. This allows us to get the speed
benefits of #2757 as
well as the flexibility benefits of not having to explicitly move a
store between threads.
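
As a rough illustration of the single-access idea, here is a minimal
sketch rather than the code from this commit. `CallState`, `SLOT`,
`INIT_RUNS`, `init_runs`, and the `String` error type are hypothetical
stand-ins; the real change keeps the flag next to the `CallThreadState`
pointer and reports failures as a `Trap`.

```rust
use std::cell::Cell;
use std::ptr;

// Hypothetical stand-in for the runtime's per-call state.
pub struct CallState {
    pub id: u32,
}

pub type Ptr = *const CallState;

thread_local! {
    // One slot holds both the current pointer and the "initialized" flag,
    // so the hot path touches thread-local storage only once.
    static SLOT: Cell<(Ptr, bool)> = Cell::new((ptr::null(), false));
    // Illustration only: counts how often per-thread setup ran on this thread.
    static INIT_RUNS: Cell<u32> = Cell::new(0);
}

// Hypothetical placeholder for platform setup (sigaltstack, Mach ports, ...).
fn lazy_per_thread_init() -> Result<(), String> {
    INIT_RUNS.with(|c| c.set(c.get() + 1));
    Ok(())
}

// Installs `val` as the current pointer and returns the previous one,
// running the per-thread setup the first time a thread gets here.
pub fn replace(val: Ptr) -> Result<Ptr, String> {
    SLOT.with(|slot| {
        let (prev, mut initialized) = slot.get();
        if !initialized {
            lazy_per_thread_init()?;
            initialized = true;
        }
        slot.set((val, initialized));
        Ok(prev)
    })
}

// Reads the current pointer without triggering initialization.
pub fn get() -> Ptr {
    SLOT.with(|slot| slot.get().0)
}

// Number of times the setup ran on the calling thread.
pub fn init_runs() -> u32 {
    INIT_RUNS.with(|c| c.get())
}
```

Because the flag lives in thread-local storage, a caller that ends up on
a different OS thread simply sees an uninitialized slot there, and the
setup runs again without any explicit notification.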

With this new ability this commit deletes the recently added
`Store::notify_switched_thread` method since it's no longer necessary.

* Fix a test compiling
Alex Crichton
2021-04-28 12:08:27 -05:00
committed by GitHub
parent 207da989ac
commit 7ec073cef1
6 changed files with 72 additions and 113 deletions


@@ -58,13 +58,12 @@ static mut IS_WASM_PC: fn(usize) -> bool = |_| false;
/// program counter is the pc of an actual wasm trap or not. This is then used /// program counter is the pc of an actual wasm trap or not. This is then used
/// to disambiguate faults that happen due to wasm and faults that happen due to /// to disambiguate faults that happen due to wasm and faults that happen due to
/// bugs in Rust or elsewhere. /// bugs in Rust or elsewhere.
pub fn init_traps(is_wasm_pc: fn(usize) -> bool) -> Result<(), Trap> { pub fn init_traps(is_wasm_pc: fn(usize) -> bool) {
static INIT: Once = Once::new(); static INIT: Once = Once::new();
INIT.call_once(|| unsafe { INIT.call_once(|| unsafe {
IS_WASM_PC = is_wasm_pc; IS_WASM_PC = is_wasm_pc;
sys::platform_init(); sys::platform_init();
}); });
sys::lazy_per_thread_init()
} }
/// Raises a user-defined trap immediately. /// Raises a user-defined trap immediately.
@@ -256,7 +255,7 @@ impl<'a> CallThreadState<'a> {
} }
fn with(self, closure: impl FnOnce(&CallThreadState) -> i32) -> Result<(), Trap> { fn with(self, closure: impl FnOnce(&CallThreadState) -> i32) -> Result<(), Trap> {
let ret = tls::set(&self, || closure(&self)); let ret = tls::set(&self, || closure(&self))?;
if ret != 0 { if ret != 0 {
return Ok(()); return Ok(());
} }
@@ -366,6 +365,7 @@ impl<T: Copy> Drop for ResetCell<'_, T> {
// the caller to the trap site. // the caller to the trap site.
mod tls { mod tls {
use super::CallThreadState; use super::CallThreadState;
use crate::Trap;
use std::mem; use std::mem;
use std::ptr; use std::ptr;
@@ -384,21 +384,38 @@ mod tls {
// these TLS values when the runtime may have crossed threads. // these TLS values when the runtime may have crossed threads.
mod raw { mod raw {
use super::CallThreadState; use super::CallThreadState;
use crate::Trap;
use std::cell::Cell; use std::cell::Cell;
use std::ptr; use std::ptr;
pub type Ptr = *const CallThreadState<'static>; pub type Ptr = *const CallThreadState<'static>;
thread_local!(static PTR: Cell<Ptr> = Cell::new(ptr::null())); // The first entry here is the `Ptr` which is what's used as part of the
// public interface of this module. The second entry is a boolean which
// allows the runtime to perform per-thread initialization if necessary
// for handling traps (e.g. setting up ports on macOS and sigaltstack on
// Unix).
thread_local!(static PTR: Cell<(Ptr, bool)> = Cell::new((ptr::null(), false)));
#[inline(never)] // see module docs for why this is here #[inline(never)] // see module docs for why this is here
pub fn replace(val: Ptr) -> Ptr { pub fn replace(val: Ptr) -> Result<Ptr, Trap> {
PTR.with(|p| p.replace(val)) PTR.with(|p| {
// When a new value is configured that means that we may be
// entering WebAssembly so check to see if this thread has
// performed per-thread initialization for traps.
let (prev, mut initialized) = p.get();
if !initialized {
super::super::sys::lazy_per_thread_init()?;
initialized = true;
}
p.set((val, initialized));
Ok(prev)
})
} }
#[inline(never)] // see module docs for why this is here #[inline(never)] // see module docs for why this is here
pub fn get() -> Ptr { pub fn get() -> Ptr {
PTR.with(|p| p.get()) PTR.with(|p| p.get().0)
} }
} }
@@ -412,7 +429,7 @@ mod tls {
/// ///
/// This is not a safe operation since it's intended to only be used /// This is not a safe operation since it's intended to only be used
/// with stack switching found with fibers and async wasmtime. /// with stack switching found with fibers and async wasmtime.
pub unsafe fn take() -> TlsRestore { pub unsafe fn take() -> Result<TlsRestore, Trap> {
// Our tls pointer must be set at this time, and it must not be // Our tls pointer must be set at this time, and it must not be
// null. We need to restore the previous pointer since we're // null. We need to restore the previous pointer since we're
// removing ourselves from the call-stack, and in the process we // removing ourselves from the call-stack, and in the process we
@@ -421,8 +438,8 @@ mod tls {
let raw = raw::get(); let raw = raw::get();
assert!(!raw.is_null()); assert!(!raw.is_null());
let prev = (*raw).prev.replace(ptr::null()); let prev = (*raw).prev.replace(ptr::null());
raw::replace(prev); raw::replace(prev)?;
TlsRestore(raw) Ok(TlsRestore(raw))
} }
/// Restores a previous tls state back into this thread's TLS. /// Restores a previous tls state back into this thread's TLS.
@@ -430,17 +447,12 @@ mod tls {
/// This is unsafe because it's intended to only be used within the /// This is unsafe because it's intended to only be used within the
/// context of stack switching within wasmtime. /// context of stack switching within wasmtime.
pub unsafe fn replace(self) -> Result<(), super::Trap> { pub unsafe fn replace(self) -> Result<(), super::Trap> {
// When replacing to the previous value of TLS, we might have
// crossed a thread: make sure the trap-handling lazy initializer
// runs.
super::sys::lazy_per_thread_init()?;
// We need to configure our previous TLS pointer to whatever is in // We need to configure our previous TLS pointer to whatever is in
// TLS at this time, and then we set the current state to ourselves. // TLS at this time, and then we set the current state to ourselves.
let prev = raw::get(); let prev = raw::get();
assert!((*self.0).prev.get().is_null()); assert!((*self.0).prev.get().is_null());
(*self.0).prev.set(prev); (*self.0).prev.set(prev);
raw::replace(self.0); raw::replace(self.0)?;
Ok(()) Ok(())
} }
} }
@@ -448,13 +460,14 @@ mod tls {
/// Configures thread local state such that for the duration of the /// Configures thread local state such that for the duration of the
/// execution of `closure` any call to `with` will yield `ptr`, unless this /// execution of `closure` any call to `with` will yield `ptr`, unless this
/// is recursively called again. /// is recursively called again.
pub fn set<R>(state: &CallThreadState<'_>, closure: impl FnOnce() -> R) -> R { pub fn set<R>(state: &CallThreadState<'_>, closure: impl FnOnce() -> R) -> Result<R, Trap> {
struct Reset<'a, 'b>(&'a CallThreadState<'b>); struct Reset<'a, 'b>(&'a CallThreadState<'b>);
impl Drop for Reset<'_, '_> { impl Drop for Reset<'_, '_> {
#[inline] #[inline]
fn drop(&mut self) { fn drop(&mut self) {
raw::replace(self.0.prev.replace(ptr::null())); raw::replace(self.0.prev.replace(ptr::null()))
.expect("tls should be previously initialized");
} }
} }
@@ -464,10 +477,10 @@ mod tls {
let ptr = unsafe { let ptr = unsafe {
mem::transmute::<*const CallThreadState<'_>, *const CallThreadState<'static>>(state) mem::transmute::<*const CallThreadState<'_>, *const CallThreadState<'static>>(state)
}; };
let prev = raw::replace(ptr); let prev = raw::replace(ptr)?;
state.prev.set(prev); state.prev.set(prev);
let _reset = Reset(state); let _reset = Reset(state);
closure() Ok(closure())
} }
/// Returns the last pointer configured with `set` above. Panics if `set` /// Returns the last pointer configured with `set` above. Panics if `set`


@@ -42,7 +42,6 @@ use mach::message::*;
use mach::port::*; use mach::port::*;
use mach::thread_act::*; use mach::thread_act::*;
use mach::traps::*; use mach::traps::*;
use std::cell::Cell;
use std::mem; use std::mem;
use std::thread; use std::thread;
@@ -425,26 +424,16 @@ impl Drop for ClosePort {
/// task-level port which is where we'd expected things like breakpad/crashpad /// task-level port which is where we'd expected things like breakpad/crashpad
/// exception handlers to get registered. /// exception handlers to get registered.
pub fn lazy_per_thread_init() -> Result<(), Trap> { pub fn lazy_per_thread_init() -> Result<(), Trap> {
thread_local! { unsafe {
static PORTS_SET: Cell<bool> = Cell::new(false); assert!(WASMTIME_PORT != MACH_PORT_NULL);
let kret = thread_set_exception_ports(
MY_PORT.with(|p| p.0),
EXC_MASK_BAD_ACCESS | EXC_MASK_BAD_INSTRUCTION,
WASMTIME_PORT,
EXCEPTION_DEFAULT | MACH_EXCEPTION_CODES,
mach_addons::THREAD_STATE_NONE,
);
assert_eq!(kret, KERN_SUCCESS, "failed to set thread exception port");
} }
PORTS_SET.with(|ports| {
if ports.replace(true) {
return;
}
unsafe {
assert!(WASMTIME_PORT != MACH_PORT_NULL);
let kret = thread_set_exception_ports(
MY_PORT.with(|p| p.0),
EXC_MASK_BAD_ACCESS | EXC_MASK_BAD_INSTRUCTION,
WASMTIME_PORT,
EXCEPTION_DEFAULT | MACH_EXCEPTION_CODES,
mach_addons::THREAD_STATE_NONE,
);
assert_eq!(kret, KERN_SUCCESS, "failed to set thread exception port");
}
});
Ok(()) Ok(())
} }


@@ -154,41 +154,35 @@ unsafe fn get_pc(cx: *mut libc::c_void) -> *const u8 {
/// and registering our own alternate stack that is large enough and has a guard /// and registering our own alternate stack that is large enough and has a guard
/// page. /// page.
pub fn lazy_per_thread_init() -> Result<(), Trap> { pub fn lazy_per_thread_init() -> Result<(), Trap> {
// This thread local is purely used to register a `Stack` to get deallocated
// when the thread exits. Otherwise this function is only ever called at
// most once per-thread.
thread_local! { thread_local! {
/// Thread-local state is lazy-initialized on the first time it's used, static STACK: RefCell<Option<Stack>> = RefCell::new(None);
/// and dropped when the thread exits.
static TLS: RefCell<Tls> = RefCell::new(Tls::None);
} }
/// The size of the sigaltstack (not including the guard, which will be /// The size of the sigaltstack (not including the guard, which will be
/// added). Make this large enough to run our signal handlers. /// added). Make this large enough to run our signal handlers.
const MIN_STACK_SIZE: usize = 16 * 4096; const MIN_STACK_SIZE: usize = 16 * 4096;
enum Tls { struct Stack {
None, mmap_ptr: *mut libc::c_void,
Allocated { mmap_size: usize,
mmap_ptr: *mut libc::c_void,
mmap_size: usize,
},
BigEnough,
} }
return TLS.with(|slot| unsafe { return STACK.with(|s| {
let mut slot = slot.borrow_mut(); *s.borrow_mut() = unsafe { allocate_sigaltstack()? };
match *slot { Ok(())
Tls::None => {} });
// already checked
_ => return Ok(()),
}
unsafe fn allocate_sigaltstack() -> Result<Option<Stack>, Trap> {
// Check to see if the existing sigaltstack, if it exists, is big // Check to see if the existing sigaltstack, if it exists, is big
// enough. If so we don't need to allocate our own. // enough. If so we don't need to allocate our own.
let mut old_stack = mem::zeroed(); let mut old_stack = mem::zeroed();
let r = libc::sigaltstack(ptr::null(), &mut old_stack); let r = libc::sigaltstack(ptr::null(), &mut old_stack);
assert_eq!(r, 0, "learning about sigaltstack failed"); assert_eq!(r, 0, "learning about sigaltstack failed");
if old_stack.ss_flags & libc::SS_DISABLE == 0 && old_stack.ss_size >= MIN_STACK_SIZE { if old_stack.ss_flags & libc::SS_DISABLE == 0 && old_stack.ss_size >= MIN_STACK_SIZE {
*slot = Tls::BigEnough; return Ok(None);
return Ok(());
} }
// ... but failing that we need to allocate our own, so do all that // ... but failing that we need to allocate our own, so do all that
@@ -226,25 +220,17 @@ pub fn lazy_per_thread_init() -> Result<(), Trap> {
let r = libc::sigaltstack(&new_stack, ptr::null_mut()); let r = libc::sigaltstack(&new_stack, ptr::null_mut());
assert_eq!(r, 0, "registering new sigaltstack failed"); assert_eq!(r, 0, "registering new sigaltstack failed");
*slot = Tls::Allocated { Ok(Some(Stack {
mmap_ptr: ptr, mmap_ptr: ptr,
mmap_size: alloc_size, mmap_size: alloc_size,
}; }))
Ok(()) }
});
impl Drop for Tls { impl Drop for Stack {
fn drop(&mut self) { fn drop(&mut self) {
let (ptr, size) = match self {
Tls::Allocated {
mmap_ptr,
mmap_size,
} => (*mmap_ptr, *mmap_size),
_ => return,
};
unsafe { unsafe {
// Deallocate the stack memory. // Deallocate the stack memory.
let r = libc::munmap(ptr, size); let r = libc::munmap(self.mmap_ptr, self.mmap_size);
debug_assert_eq!(r, 0, "munmap failed during thread shutdown"); debug_assert_eq!(r, 0, "munmap failed during thread shutdown");
} }
} }


@@ -159,13 +159,10 @@ impl Store {
} }
fn new_(engine: &Engine, limiter: Option<Rc<dyn wasmtime_runtime::ResourceLimiter>>) -> Self { fn new_(engine: &Engine, limiter: Option<Rc<dyn wasmtime_runtime::ResourceLimiter>>) -> Self {
// Ensure that wasmtime_runtime's signal handlers are configured. Note // Ensure that wasmtime_runtime's signal handlers are configured. This
// that at the `Store` level it means we should perform this // is the per-program initialization required for handling traps, such
// once-per-thread. Platforms like Unix, however, only require this // as configuring signals, vectored exception handlers, etc.
// once-per-program. In any case this is safe to call many times and wasmtime_runtime::init_traps(crate::module::GlobalModuleRegistry::is_wasm_pc);
// each one that's not relevant just won't do anything.
wasmtime_runtime::init_traps(crate::module::GlobalModuleRegistry::is_wasm_pc)
.expect("failed to initialize trap handling");
Self { Self {
inner: Rc::new(StoreInner { inner: Rc::new(StoreInner {
@@ -451,25 +448,6 @@ impl Store {
&self.inner.modules &self.inner.modules
} }
/// Notifies that the current Store (and all referenced entities) has been moved over to a
/// different thread.
///
/// See also the multithreading documentation for more details:
/// <https://docs.wasmtime.dev/examples-rust-multithreading.html>.
///
/// # Safety
///
/// In general, it's not possible to move a `Store` to a different thread, because it isn't `Send`.
/// That being said, it is possible to create an unsafe `Send` wrapper over a `Store`, assuming
/// the safety guidelines exposed in the multithreading documentation have been applied. So it
/// is in general unnecessary to do this, if you're not doing unsafe things.
///
/// It is fine to call this several times: only the first call will have an effect.
pub unsafe fn notify_switched_thread(&self) {
wasmtime_runtime::init_traps(crate::module::GlobalModuleRegistry::is_wasm_pc)
.expect("failed to initialize per-threads traps");
}
#[inline] #[inline]
pub(crate) fn module_info_lookup(&self) -> &dyn wasmtime_runtime::ModuleInfoLookup { pub(crate) fn module_info_lookup(&self) -> &dyn wasmtime_runtime::ModuleInfoLookup {
self.inner.as_ref() self.inner.as_ref()
@@ -673,7 +651,8 @@ impl Store {
} }
unsafe { unsafe {
let before = wasmtime_runtime::TlsRestore::take(); let before = wasmtime_runtime::TlsRestore::take()
.map_err(|e| Trap::from_runtime(self, e))?;
let res = (*suspend).suspend(()); let res = (*suspend).suspend(());
before.replace().map_err(|e| Trap::from_runtime(self, e))?; before.replace().map_err(|e| Trap::from_runtime(self, e))?;
res?; res?;


@@ -129,16 +129,11 @@ some possibilities include:
 `Store::set` or `Func::wrap`) implement the `Send` trait.

 If these requirements are met it is technically safe to move a store and its
-objects between threads. When you move a store to another thread, it is
-required that you run the `Store::notify_switched_thread()` method after the
-store has landed on the new thread, so that per-thread initialization is
-correctly re-run. Failure to do so may cause wasm traps to crash the whole
-application.
-
-The reason that this strategy isn't recommended, however, is that you will
-receive no assistance from the Rust compiler in verifying that the transfer
-across threads is indeed actually safe. This will require auditing your
-embedding of Wasmtime itself to ensure it meets these requirements.
+objects between threads. The reason that this strategy isn't recommended,
+however, is that you will receive no assistance from the Rust compiler in
+verifying that the transfer across threads is indeed actually safe. This will
+require auditing your embedding of Wasmtime itself to ensure it meets these
+requirements.

 It's important to note that the requirements here also apply to the futures
 returned from `Func::call_async`. These futures are not `Send` due to them


@@ -616,9 +616,6 @@ fn multithreaded_traps() -> Result<()> {
let handle = std::thread::spawn(move || { let handle = std::thread::spawn(move || {
let instance = instance.inner; let instance = instance.inner;
unsafe {
instance.store().notify_switched_thread();
}
assert!(instance assert!(instance
.get_typed_func::<(), ()>("run") .get_typed_func::<(), ()>("run")
.unwrap() .unwrap()