Bring back per-thread lazy initialization (#2863)

* Bring back per-thread lazy initialization

Platforms Wasmtime supports may have per-thread initialization that
needs to run before WebAssembly. For example Unix needs to setup a
sigaltstack and macOS needs to set up mach ports. In #2757 this
per-thread setup was moved out of the invocation of a wasm function,
relying on the lack of Send for Store to initialize the thread at Store
creation time and never worry about it later.

This conflicted with [wasmtime's desired multithreading
story](https://github.com/bytecodealliance/wasmtime/pull/2812) so a new
[`Store::notify_switched_thread` was
added](https://github.com/bytecodealliance/wasmtime/pull/2822) to
explicitly indicate a Store has moved to another thread (if it unsafely
did so).

It turns out though that it's not always easy to determine when a
`Store` moves to a new thread. For example the Go bindings for Wasmtime
are generally unaware when a goroutine switches OS threads. This led to
https://github.com/bytecodealliance/wasmtime-go/issues/74 where a SIGILL
was left uncaught, making it appear that traps aren't working properly.

This commit revisits the decision in #2757 and moves per-thread
initialization back into the path of calling into WebAssembly. This is
differently from before, though, where there's still only one TLS access
on the path of calling into WebAssembly, unlike before where it was a
separate access. This allows us to get the speed benefits of #2757 as
well as the flexibility benefits of not having to explicitly move a
store between threads.

With this new ability this commit deletes the recently added
`Store::notify_switched_thread` method since it's no longer necessary.

* Fix a test compiling
This commit is contained in:
Alex Crichton
2021-04-28 12:08:27 -05:00
committed by GitHub
parent 207da989ac
commit 7ec073cef1
6 changed files with 72 additions and 113 deletions

View File

@@ -58,13 +58,12 @@ static mut IS_WASM_PC: fn(usize) -> bool = |_| false;
/// program counter is the pc of an actual wasm trap or not. This is then used
/// to disambiguate faults that happen due to wasm and faults that happen due to
/// bugs in Rust or elsewhere.
pub fn init_traps(is_wasm_pc: fn(usize) -> bool) -> Result<(), Trap> {
pub fn init_traps(is_wasm_pc: fn(usize) -> bool) {
static INIT: Once = Once::new();
INIT.call_once(|| unsafe {
IS_WASM_PC = is_wasm_pc;
sys::platform_init();
});
sys::lazy_per_thread_init()
}
/// Raises a user-defined trap immediately.
@@ -256,7 +255,7 @@ impl<'a> CallThreadState<'a> {
}
fn with(self, closure: impl FnOnce(&CallThreadState) -> i32) -> Result<(), Trap> {
let ret = tls::set(&self, || closure(&self));
let ret = tls::set(&self, || closure(&self))?;
if ret != 0 {
return Ok(());
}
@@ -366,6 +365,7 @@ impl<T: Copy> Drop for ResetCell<'_, T> {
// the caller to the trap site.
mod tls {
use super::CallThreadState;
use crate::Trap;
use std::mem;
use std::ptr;
@@ -384,21 +384,38 @@ mod tls {
// these TLS values when the runtime may have crossed threads.
mod raw {
use super::CallThreadState;
use crate::Trap;
use std::cell::Cell;
use std::ptr;
pub type Ptr = *const CallThreadState<'static>;
thread_local!(static PTR: Cell<Ptr> = Cell::new(ptr::null()));
// The first entry here is the `Ptr` which is what's used as part of the
// public interface of this module. The second entry is a boolean which
// allows the runtime to perform per-thread initialization if necessary
// for handling traps (e.g. setting up ports on macOS and sigaltstack on
// Unix).
thread_local!(static PTR: Cell<(Ptr, bool)> = Cell::new((ptr::null(), false)));
#[inline(never)] // see module docs for why this is here
pub fn replace(val: Ptr) -> Ptr {
PTR.with(|p| p.replace(val))
pub fn replace(val: Ptr) -> Result<Ptr, Trap> {
PTR.with(|p| {
// When a new value is configured that means that we may be
// entering WebAssembly so check to see if this thread has
// performed per-thread initialization for traps.
let (prev, mut initialized) = p.get();
if !initialized {
super::super::sys::lazy_per_thread_init()?;
initialized = true;
}
p.set((val, initialized));
Ok(prev)
})
}
#[inline(never)] // see module docs for why this is here
pub fn get() -> Ptr {
PTR.with(|p| p.get())
PTR.with(|p| p.get().0)
}
}
@@ -412,7 +429,7 @@ mod tls {
///
/// This is not a safe operation since it's intended to only be used
/// with stack switching found with fibers and async wasmtime.
pub unsafe fn take() -> TlsRestore {
pub unsafe fn take() -> Result<TlsRestore, Trap> {
// Our tls pointer must be set at this time, and it must not be
// null. We need to restore the previous pointer since we're
// removing ourselves from the call-stack, and in the process we
@@ -421,8 +438,8 @@ mod tls {
let raw = raw::get();
assert!(!raw.is_null());
let prev = (*raw).prev.replace(ptr::null());
raw::replace(prev);
TlsRestore(raw)
raw::replace(prev)?;
Ok(TlsRestore(raw))
}
/// Restores a previous tls state back into this thread's TLS.
@@ -430,17 +447,12 @@ mod tls {
/// This is unsafe because it's intended to only be used within the
/// context of stack switching within wasmtime.
pub unsafe fn replace(self) -> Result<(), super::Trap> {
// When replacing to the previous value of TLS, we might have
// crossed a thread: make sure the trap-handling lazy initializer
// runs.
super::sys::lazy_per_thread_init()?;
// We need to configure our previous TLS pointer to whatever is in
// TLS at this time, and then we set the current state to ourselves.
let prev = raw::get();
assert!((*self.0).prev.get().is_null());
(*self.0).prev.set(prev);
raw::replace(self.0);
raw::replace(self.0)?;
Ok(())
}
}
@@ -448,13 +460,14 @@ mod tls {
/// Configures thread local state such that for the duration of the
/// execution of `closure` any call to `with` will yield `ptr`, unless this
/// is recursively called again.
pub fn set<R>(state: &CallThreadState<'_>, closure: impl FnOnce() -> R) -> R {
pub fn set<R>(state: &CallThreadState<'_>, closure: impl FnOnce() -> R) -> Result<R, Trap> {
struct Reset<'a, 'b>(&'a CallThreadState<'b>);
impl Drop for Reset<'_, '_> {
#[inline]
fn drop(&mut self) {
raw::replace(self.0.prev.replace(ptr::null()));
raw::replace(self.0.prev.replace(ptr::null()))
.expect("tls should be previously initialized");
}
}
@@ -464,10 +477,10 @@ mod tls {
let ptr = unsafe {
mem::transmute::<*const CallThreadState<'_>, *const CallThreadState<'static>>(state)
};
let prev = raw::replace(ptr);
let prev = raw::replace(ptr)?;
state.prev.set(prev);
let _reset = Reset(state);
closure()
Ok(closure())
}
/// Returns the last pointer configured with `set` above. Panics if `set`