Fix some issues around TLS management with async (#2709)

This commit fixes a few issues around managing the thread-local state of
a wasmtime thread. We intentionally only have a singular TLS variable in
the whole world, and the problem is that when stack-switching off an
async thread we were not restoring the previous TLS state. This is
necessary in two cases:

* Futures aren't guaranteed to be polled/completed in a stack-like
  fashion. If a poll sees that a future isn't ready then we may resume
  execution in a previous wasm context that ends up needing the TLS
  information.

* Futures can also cross threads (when the whole store crosses threads)
  and we need to save/restore TLS state from the thread we're coming
  from and the thread that we're going to.

The stack switching issue necessitates some more glue around suspension
and resumption of a stack to ensure we save/restore the TLS state on
both sides. The thread issue, however, also necessitates that we use
`#[inline(never)]` on TLS access functions and never have TLS borrows
live across a function which could result in running arbitrary code (as
was the case for the `tls::set` function.
This commit is contained in:
Alex Crichton
2021-03-11 11:32:33 -06:00
committed by GitHub
parent 54c07d8f16
commit 918c012d00
4 changed files with 236 additions and 22 deletions

View File

@@ -48,7 +48,7 @@ pub use crate::mmap::Mmap;
pub use crate::table::{Table, TableElement};
pub use crate::traphandlers::{
catch_traps, init_traps, raise_lib_trap, raise_user_trap, resume_panic, with_last_info,
SignalHandler, Trap, TrapInfo,
SignalHandler, TlsRestore, Trap, TrapInfo,
};
pub use crate::vmcontext::{
VMCallerCheckedAnyfunc, VMContext, VMFunctionBody, VMFunctionImport, VMGlobalDefinition,

View File

@@ -12,6 +12,8 @@ use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};
use std::sync::Once;
use wasmtime_environ::ir;
pub use self::tls::TlsRestore;
extern "C" {
fn RegisterSetjmp(
jmp_buf: *mut *const u8,
@@ -491,6 +493,7 @@ pub struct CallThreadState<'a> {
jmp_buf: Cell<*const u8>,
handling_trap: Cell<bool>,
trap_info: &'a (dyn TrapInfo + 'a),
prev: Cell<tls::Ptr>,
}
/// A package of functionality needed by `catch_traps` to figure out what to do
@@ -541,6 +544,7 @@ impl<'a> CallThreadState<'a> {
jmp_buf: Cell::new(ptr::null()),
handling_trap: Cell::new(false),
trap_info,
prev: Cell::new(ptr::null()),
}
}
@@ -753,43 +757,108 @@ impl<T: Copy> Drop for ResetCell<'_, T> {
// the caller to the trap site.
mod tls {
use super::CallThreadState;
use std::cell::Cell;
use std::mem;
use std::ptr;
thread_local!(static PTR: Cell<*const CallThreadState<'static>> = Cell::new(ptr::null()));
pub use raw::Ptr;
// An even *more* inner module for dealing with TLS. This actually has the
// thread local variable and has functions to access the variable.
//
// Note that this is specially done to fully encapsulate that the accessors
// for tls must not be inlined. Wasmtime's async support employs stack
// switching which can resume execution on different OS threads. This means
// that borrows of our TLS pointer must never live across accesses because
// otherwise the access may be split across two threads and cause unsafety.
//
// This also means that extra care is taken by the runtime to save/restore
// these TLS values when the runtime may have crossed threads.
mod raw {
use super::CallThreadState;
use std::cell::Cell;
use std::ptr;
pub type Ptr = *const CallThreadState<'static>;
thread_local!(static PTR: Cell<Ptr> = Cell::new(ptr::null()));
#[inline(never)] // see module docs for why this is here
pub fn replace(val: Ptr) -> Ptr {
PTR.with(|p| p.replace(val))
}
#[inline(never)] // see module docs for why this is here
pub fn get() -> Ptr {
PTR.with(|p| p.get())
}
}
/// Opaque state used to help control TLS state across stack switches for
/// async support.
pub struct TlsRestore(raw::Ptr);
impl TlsRestore {
/// Takes the TLS state that is currently configured and returns a
/// token that is used to replace it later.
///
/// This is not a safe operation since it's intended to only be used
/// with stack switching found with fibers and async wasmtime.
pub unsafe fn take() -> TlsRestore {
// Our tls pointer must be set at this time, and it must not be
// null. We need to restore the previous pointer since we're
// removing ourselves from the call-stack, and in the process we
// null out our own previous field for safety in case it's
// accidentally used later.
let raw = raw::get();
assert!(!raw.is_null());
let prev = (*raw).prev.replace(ptr::null());
raw::replace(prev);
TlsRestore(raw)
}
/// Restores a previous tls state back into this thread's TLS.
///
/// This is unsafe because it's intended to only be used within the
/// context of stack switching within wasmtime.
pub unsafe fn replace(self) {
// We need to configure our previous TLS pointer to whatever is in
// TLS at this time, and then we set the current state to ourselves.
let prev = raw::get();
assert!((*self.0).prev.get().is_null());
(*self.0).prev.set(prev);
raw::replace(self.0);
}
}
/// Configures thread local state such that for the duration of the
/// execution of `closure` any call to `with` will yield `ptr`, unless this
/// is recursively called again.
pub fn set<R>(ptr: &CallThreadState<'_>, closure: impl FnOnce() -> R) -> R {
struct Reset<'a, T: Copy>(&'a Cell<T>, T);
pub fn set<R>(state: &CallThreadState<'_>, closure: impl FnOnce() -> R) -> R {
struct Reset<'a, 'b>(&'a CallThreadState<'b>);
impl<T: Copy> Drop for Reset<'_, T> {
impl Drop for Reset<'_, '_> {
fn drop(&mut self) {
self.0.set(self.1);
raw::replace(self.0.prev.replace(ptr::null()));
}
}
PTR.with(|p| {
// Note that this extension of the lifetime to `'static` should be
// safe because we only ever access it below with an anonymous
// lifetime, meaning `'static` never leaks out of this module.
let ptr = unsafe {
mem::transmute::<*const CallThreadState<'_>, *const CallThreadState<'static>>(ptr)
};
let _r = Reset(p, p.replace(ptr));
closure()
})
// Note that this extension of the lifetime to `'static` should be
// safe because we only ever access it below with an anonymous
// lifetime, meaning `'static` never leaks out of this module.
let ptr = unsafe {
mem::transmute::<*const CallThreadState<'_>, *const CallThreadState<'static>>(state)
};
let prev = raw::replace(ptr);
state.prev.set(prev);
let _reset = Reset(state);
closure()
}
/// Returns the last pointer configured with `set` above. Panics if `set`
/// has not been previously called.
pub fn with<R>(closure: impl FnOnce(Option<&CallThreadState<'_>>) -> R) -> R {
PTR.with(|ptr| {
let p = ptr.get();
unsafe { closure(if p.is_null() { None } else { Some(&*p) }) }
})
let p = raw::get();
unsafe { closure(if p.is_null() { None } else { Some(&*p) }) }
}
}

View File

@@ -717,7 +717,10 @@ impl Store {
Poll::Pending => {}
}
unsafe {
(*suspend).suspend(())?;
let before = wasmtime_runtime::TlsRestore::take();
let res = (*suspend).suspend(());
before.replace();
res?;
}
}
}