Fix some issues around TLS management with async (#2709)
This commit fixes a few issues around managing the thread-local state of a wasmtime thread. We intentionally only have a single TLS variable in the whole world, and the problem is that when stack-switching off an async thread we were not restoring the previous TLS state. This is necessary in two cases:

* Futures aren't guaranteed to be polled/completed in a stack-like fashion. If a poll sees that a future isn't ready then we may resume execution in a previous wasm context that ends up needing the TLS information.
* Futures can also cross threads (when the whole store crosses threads) and we need to save/restore TLS state from the thread we're coming from and the thread that we're going to.

The stack switching issue necessitates some more glue around suspension and resumption of a stack to ensure we save/restore the TLS state on both sides. The thread issue, however, also necessitates that we use `#[inline(never)]` on TLS access functions and never have TLS borrows live across a function which could result in running arbitrary code (as was the case for the `tls::set` function).
This commit is contained in:
@@ -48,7 +48,7 @@ pub use crate::mmap::Mmap;
|
|||||||
pub use crate::table::{Table, TableElement};
|
pub use crate::table::{Table, TableElement};
|
||||||
pub use crate::traphandlers::{
|
pub use crate::traphandlers::{
|
||||||
catch_traps, init_traps, raise_lib_trap, raise_user_trap, resume_panic, with_last_info,
|
catch_traps, init_traps, raise_lib_trap, raise_user_trap, resume_panic, with_last_info,
|
||||||
SignalHandler, Trap, TrapInfo,
|
SignalHandler, TlsRestore, Trap, TrapInfo,
|
||||||
};
|
};
|
||||||
pub use crate::vmcontext::{
|
pub use crate::vmcontext::{
|
||||||
VMCallerCheckedAnyfunc, VMContext, VMFunctionBody, VMFunctionImport, VMGlobalDefinition,
|
VMCallerCheckedAnyfunc, VMContext, VMFunctionBody, VMFunctionImport, VMGlobalDefinition,
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};
|
|||||||
use std::sync::Once;
|
use std::sync::Once;
|
||||||
use wasmtime_environ::ir;
|
use wasmtime_environ::ir;
|
||||||
|
|
||||||
|
pub use self::tls::TlsRestore;
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
fn RegisterSetjmp(
|
fn RegisterSetjmp(
|
||||||
jmp_buf: *mut *const u8,
|
jmp_buf: *mut *const u8,
|
||||||
@@ -491,6 +493,7 @@ pub struct CallThreadState<'a> {
|
|||||||
jmp_buf: Cell<*const u8>,
|
jmp_buf: Cell<*const u8>,
|
||||||
handling_trap: Cell<bool>,
|
handling_trap: Cell<bool>,
|
||||||
trap_info: &'a (dyn TrapInfo + 'a),
|
trap_info: &'a (dyn TrapInfo + 'a),
|
||||||
|
prev: Cell<tls::Ptr>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A package of functionality needed by `catch_traps` to figure out what to do
|
/// A package of functionality needed by `catch_traps` to figure out what to do
|
||||||
@@ -541,6 +544,7 @@ impl<'a> CallThreadState<'a> {
|
|||||||
jmp_buf: Cell::new(ptr::null()),
|
jmp_buf: Cell::new(ptr::null()),
|
||||||
handling_trap: Cell::new(false),
|
handling_trap: Cell::new(false),
|
||||||
trap_info,
|
trap_info,
|
||||||
|
prev: Cell::new(ptr::null()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -753,43 +757,108 @@ impl<T: Copy> Drop for ResetCell<'_, T> {
|
|||||||
// the caller to the trap site.
|
// the caller to the trap site.
|
||||||
mod tls {
|
mod tls {
|
||||||
use super::CallThreadState;
|
use super::CallThreadState;
|
||||||
use std::cell::Cell;
|
|
||||||
use std::mem;
|
use std::mem;
|
||||||
use std::ptr;
|
use std::ptr;
|
||||||
|
|
||||||
thread_local!(static PTR: Cell<*const CallThreadState<'static>> = Cell::new(ptr::null()));
|
pub use raw::Ptr;
|
||||||
|
|
||||||
|
// An even *more* inner module for dealing with TLS. This actually has the
|
||||||
|
// thread local variable and has functions to access the variable.
|
||||||
|
//
|
||||||
|
// Note that this is specially done to fully encapsulate that the accessors
|
||||||
|
// for tls must not be inlined. Wasmtime's async support employs stack
|
||||||
|
// switching which can resume execution on different OS threads. This means
|
||||||
|
// that borrows of our TLS pointer must never live across accesses because
|
||||||
|
// otherwise the access may be split across two threads and cause unsafety.
|
||||||
|
//
|
||||||
|
// This also means that extra care is taken by the runtime to save/restore
|
||||||
|
// these TLS values when the runtime may have crossed threads.
|
||||||
|
mod raw {
|
||||||
|
use super::CallThreadState;
|
||||||
|
use std::cell::Cell;
|
||||||
|
use std::ptr;
|
||||||
|
|
||||||
|
pub type Ptr = *const CallThreadState<'static>;
|
||||||
|
|
||||||
|
thread_local!(static PTR: Cell<Ptr> = Cell::new(ptr::null()));
|
||||||
|
|
||||||
|
#[inline(never)] // see module docs for why this is here
|
||||||
|
pub fn replace(val: Ptr) -> Ptr {
|
||||||
|
PTR.with(|p| p.replace(val))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(never)] // see module docs for why this is here
|
||||||
|
pub fn get() -> Ptr {
|
||||||
|
PTR.with(|p| p.get())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Opaque state used to help control TLS state across stack switches for
|
||||||
|
/// async support.
|
||||||
|
pub struct TlsRestore(raw::Ptr);
|
||||||
|
|
||||||
|
impl TlsRestore {
|
||||||
|
/// Takes the TLS state that is currently configured and returns a
|
||||||
|
/// token that is used to replace it later.
|
||||||
|
///
|
||||||
|
/// This is not a safe operation since it's intended to only be used
|
||||||
|
/// with stack switching found with fibers and async wasmtime.
|
||||||
|
pub unsafe fn take() -> TlsRestore {
|
||||||
|
// Our tls pointer must be set at this time, and it must not be
|
||||||
|
// null. We need to restore the previous pointer since we're
|
||||||
|
// removing ourselves from the call-stack, and in the process we
|
||||||
|
// null out our own previous field for safety in case it's
|
||||||
|
// accidentally used later.
|
||||||
|
let raw = raw::get();
|
||||||
|
assert!(!raw.is_null());
|
||||||
|
let prev = (*raw).prev.replace(ptr::null());
|
||||||
|
raw::replace(prev);
|
||||||
|
TlsRestore(raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Restores a previous tls state back into this thread's TLS.
|
||||||
|
///
|
||||||
|
/// This is unsafe because it's intended to only be used within the
|
||||||
|
/// context of stack switching within wasmtime.
|
||||||
|
pub unsafe fn replace(self) {
|
||||||
|
// We need to configure our previous TLS pointer to whatever is in
|
||||||
|
// TLS at this time, and then we set the current state to ourselves.
|
||||||
|
let prev = raw::get();
|
||||||
|
assert!((*self.0).prev.get().is_null());
|
||||||
|
(*self.0).prev.set(prev);
|
||||||
|
raw::replace(self.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Configures thread local state such that for the duration of the
|
/// Configures thread local state such that for the duration of the
|
||||||
/// execution of `closure` any call to `with` will yield `ptr`, unless this
|
/// execution of `closure` any call to `with` will yield `ptr`, unless this
|
||||||
/// is recursively called again.
|
/// is recursively called again.
|
||||||
pub fn set<R>(ptr: &CallThreadState<'_>, closure: impl FnOnce() -> R) -> R {
|
pub fn set<R>(state: &CallThreadState<'_>, closure: impl FnOnce() -> R) -> R {
|
||||||
struct Reset<'a, T: Copy>(&'a Cell<T>, T);
|
struct Reset<'a, 'b>(&'a CallThreadState<'b>);
|
||||||
|
|
||||||
impl<T: Copy> Drop for Reset<'_, T> {
|
impl Drop for Reset<'_, '_> {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
self.0.set(self.1);
|
raw::replace(self.0.prev.replace(ptr::null()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
PTR.with(|p| {
|
// Note that this extension of the lifetime to `'static` should be
|
||||||
// Note that this extension of the lifetime to `'static` should be
|
// safe because we only ever access it below with an anonymous
|
||||||
// safe because we only ever access it below with an anonymous
|
// lifetime, meaning `'static` never leaks out of this module.
|
||||||
// lifetime, meaning `'static` never leaks out of this module.
|
let ptr = unsafe {
|
||||||
let ptr = unsafe {
|
mem::transmute::<*const CallThreadState<'_>, *const CallThreadState<'static>>(state)
|
||||||
mem::transmute::<*const CallThreadState<'_>, *const CallThreadState<'static>>(ptr)
|
};
|
||||||
};
|
let prev = raw::replace(ptr);
|
||||||
let _r = Reset(p, p.replace(ptr));
|
state.prev.set(prev);
|
||||||
closure()
|
let _reset = Reset(state);
|
||||||
})
|
closure()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Yields the last pointer configured with `set` above to `closure`, or
|
/// Yields the last pointer configured with `set` above to `closure`, or
|
||||||
/// `None` if `set` has not been previously called.
|
/// `None` if `set` has not been previously called.
|
||||||
pub fn with<R>(closure: impl FnOnce(Option<&CallThreadState<'_>>) -> R) -> R {
|
pub fn with<R>(closure: impl FnOnce(Option<&CallThreadState<'_>>) -> R) -> R {
|
||||||
PTR.with(|ptr| {
|
let p = raw::get();
|
||||||
let p = ptr.get();
|
unsafe { closure(if p.is_null() { None } else { Some(&*p) }) }
|
||||||
unsafe { closure(if p.is_null() { None } else { Some(&*p) }) }
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -717,7 +717,10 @@ impl Store {
|
|||||||
Poll::Pending => {}
|
Poll::Pending => {}
|
||||||
}
|
}
|
||||||
unsafe {
|
unsafe {
|
||||||
(*suspend).suspend(())?;
|
let before = wasmtime_runtime::TlsRestore::take();
|
||||||
|
let res = (*suspend).suspend(());
|
||||||
|
before.replace();
|
||||||
|
res?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
use std::cell::Cell;
|
use std::cell::Cell;
|
||||||
|
use std::cell::RefCell;
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
@@ -493,3 +494,144 @@ fn async_host_func_with_pooling_stacks() {
|
|||||||
run_smoke_test(&func);
|
run_smoke_test(&func);
|
||||||
run_smoke_get0_test(&func);
|
run_smoke_get0_test(&func);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn execute_across_threads<F: Future + 'static>(future: F) {
|
||||||
|
struct UnsafeSend<T>(T);
|
||||||
|
unsafe impl<T> Send for UnsafeSend<T> {}
|
||||||
|
|
||||||
|
impl<T: Future> Future for UnsafeSend<T> {
|
||||||
|
type Output = T::Output;
|
||||||
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T::Output> {
|
||||||
|
unsafe { self.map_unchecked_mut(|p| &mut p.0).poll(cx) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut future = Pin::from(Box::new(UnsafeSend(future)));
|
||||||
|
let poll = future
|
||||||
|
.as_mut()
|
||||||
|
.poll(&mut Context::from_waker(&dummy_waker()));
|
||||||
|
assert!(poll.is_pending());
|
||||||
|
|
||||||
|
std::thread::spawn(move || {
|
||||||
|
let poll = future
|
||||||
|
.as_mut()
|
||||||
|
.poll(&mut Context::from_waker(&dummy_waker()));
|
||||||
|
assert!(!poll.is_pending());
|
||||||
|
})
|
||||||
|
.join()
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resume_separate_thread() {
|
||||||
|
// This test will poll the following future on two threads. Simulating a
|
||||||
|
// trap requires accessing TLS info, so that should be preserved correctly.
|
||||||
|
execute_across_threads(async {
|
||||||
|
let store = async_store();
|
||||||
|
let module = Module::new(
|
||||||
|
store.engine(),
|
||||||
|
"
|
||||||
|
(module
|
||||||
|
(import \"\" \"\" (func))
|
||||||
|
(start 0)
|
||||||
|
)
|
||||||
|
",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let func = Func::wrap0_async(&store, (), |_, _| {
|
||||||
|
Box::new(async {
|
||||||
|
PendingOnce::default().await;
|
||||||
|
Err::<(), _>(wasmtime::Trap::new("test"))
|
||||||
|
})
|
||||||
|
});
|
||||||
|
let result = Instance::new_async(&store, &module, &[func.into()]).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resume_separate_thread2() {
|
||||||
|
// This test will poll the following future on two threads. Catching a
|
||||||
|
// signal requires looking up TLS information to determine whether it's a
|
||||||
|
// trap to handle or not, so that must be preserved correctly across threads.
|
||||||
|
execute_across_threads(async {
|
||||||
|
let store = async_store();
|
||||||
|
let module = Module::new(
|
||||||
|
store.engine(),
|
||||||
|
"
|
||||||
|
(module
|
||||||
|
(import \"\" \"\" (func))
|
||||||
|
(func $start
|
||||||
|
call 0
|
||||||
|
unreachable)
|
||||||
|
(start $start)
|
||||||
|
)
|
||||||
|
",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let func = Func::wrap0_async(&store, (), |_, _| {
|
||||||
|
Box::new(async { PendingOnce::default().await })
|
||||||
|
});
|
||||||
|
let result = Instance::new_async(&store, &module, &[func.into()]).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resume_separate_thread3() {
|
||||||
|
// This test doesn't actually do anything with cross-thread polls, but
|
||||||
|
// instead it deals with scheduling futures at "odd" times.
|
||||||
|
//
|
||||||
|
// First we'll set up a *synchronous* call which will initialize TLS info.
|
||||||
|
// This call is simply to a host-defined function, but it still has the same
|
||||||
|
// "enter into wasm" semantics since it's just calling a trampoline. In this
|
||||||
|
// situation we'll set up the TLS info so it's in place while the body of
|
||||||
|
// the function executes...
|
||||||
|
let store = Store::default();
|
||||||
|
let storage = Rc::new(RefCell::new(None));
|
||||||
|
let storage2 = storage.clone();
|
||||||
|
let f = Func::wrap(&store, move || {
|
||||||
|
// ... and the execution of this host-defined function (while the TLS
|
||||||
|
// info is initialized), will set up a recursive call into wasm. This
|
||||||
|
// recursive call will be done asynchronously so we can suspend it
|
||||||
|
// halfway through.
|
||||||
|
let f = async {
|
||||||
|
let store = async_store();
|
||||||
|
let module = Module::new(
|
||||||
|
store.engine(),
|
||||||
|
"
|
||||||
|
(module
|
||||||
|
(import \"\" \"\" (func))
|
||||||
|
(start 0)
|
||||||
|
)
|
||||||
|
",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let func = Func::wrap0_async(&store, (), |_, _| {
|
||||||
|
Box::new(async { PendingOnce::default().await })
|
||||||
|
});
|
||||||
|
drop(Instance::new_async(&store, &module, &[func.into()]).await);
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
let mut future = Pin::from(Box::new(f));
|
||||||
|
let poll = future
|
||||||
|
.as_mut()
|
||||||
|
.poll(&mut Context::from_waker(&dummy_waker()));
|
||||||
|
assert!(poll.is_pending());
|
||||||
|
|
||||||
|
// ... so at this point our call into wasm is suspended. The call into
|
||||||
|
// wasm will have overwritten TLS info, and we sure hope that the
|
||||||
|
// information is restored at this point. Note that we squirrel away the
|
||||||
|
// future somewhere else to get dropped later. If we were to drop it
|
||||||
|
// here then we would reenter the future's suspended stack to clean it
|
||||||
|
// up, which would do more alterations of TLS information we're not
|
||||||
|
// testing here.
|
||||||
|
*storage2.borrow_mut() = Some(future);
|
||||||
|
|
||||||
|
// ... all in all this function will need access to the original TLS
|
||||||
|
// information to raise the trap. This TLS information should be
|
||||||
|
// restored even though the asynchronous execution is suspended.
|
||||||
|
Err::<(), _>(wasmtime::Trap::new(""))
|
||||||
|
});
|
||||||
|
assert!(f.call(&[]).is_err());
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user