diff --git a/crates/runtime/src/lib.rs b/crates/runtime/src/lib.rs index b53e73c162..1e6cfc8cd2 100644 --- a/crates/runtime/src/lib.rs +++ b/crates/runtime/src/lib.rs @@ -48,7 +48,7 @@ pub use crate::mmap::Mmap; pub use crate::table::{Table, TableElement}; pub use crate::traphandlers::{ catch_traps, init_traps, raise_lib_trap, raise_user_trap, resume_panic, with_last_info, - SignalHandler, Trap, TrapInfo, + SignalHandler, TlsRestore, Trap, TrapInfo, }; pub use crate::vmcontext::{ VMCallerCheckedAnyfunc, VMContext, VMFunctionBody, VMFunctionImport, VMGlobalDefinition, diff --git a/crates/runtime/src/traphandlers.rs b/crates/runtime/src/traphandlers.rs index a23518af9a..0a8d290f06 100644 --- a/crates/runtime/src/traphandlers.rs +++ b/crates/runtime/src/traphandlers.rs @@ -12,6 +12,8 @@ use std::sync::atomic::{AtomicUsize, Ordering::SeqCst}; use std::sync::Once; use wasmtime_environ::ir; +pub use self::tls::TlsRestore; + extern "C" { fn RegisterSetjmp( jmp_buf: *mut *const u8, @@ -491,6 +493,7 @@ pub struct CallThreadState<'a> { jmp_buf: Cell<*const u8>, handling_trap: Cell, trap_info: &'a (dyn TrapInfo + 'a), + prev: Cell, } /// A package of functionality needed by `catch_traps` to figure out what to do @@ -541,6 +544,7 @@ impl<'a> CallThreadState<'a> { jmp_buf: Cell::new(ptr::null()), handling_trap: Cell::new(false), trap_info, + prev: Cell::new(ptr::null()), } } @@ -753,43 +757,108 @@ impl Drop for ResetCell<'_, T> { // the caller to the trap site. mod tls { use super::CallThreadState; - use std::cell::Cell; use std::mem; use std::ptr; - thread_local!(static PTR: Cell<*const CallThreadState<'static>> = Cell::new(ptr::null())); + pub use raw::Ptr; + + // An even *more* inner module for dealing with TLS. This actually has the + // thread local variable and has functions to access the variable. + // + // Note that this is specially done to fully encapsulate that the accessors + // for tls must not be inlined. Wasmtime's async support employs stack + // switching which can resume execution on different OS threads. This means + // that borrows of our TLS pointer must never live across accesses because + // otherwise the access may be split across two threads and cause unsafety. + // + // This also means that extra care is taken by the runtime to save/restore + // these TLS values when the runtime may have crossed threads. + mod raw { + use super::CallThreadState; + use std::cell::Cell; + use std::ptr; + + pub type Ptr = *const CallThreadState<'static>; + + thread_local!(static PTR: Cell = Cell::new(ptr::null())); + + #[inline(never)] // see module docs for why this is here + pub fn replace(val: Ptr) -> Ptr { + PTR.with(|p| p.replace(val)) + } + + #[inline(never)] // see module docs for why this is here + pub fn get() -> Ptr { + PTR.with(|p| p.get()) + } + } + + /// Opaque state used to help control TLS state across stack switches for + /// async support. + pub struct TlsRestore(raw::Ptr); + + impl TlsRestore { + /// Takes the TLS state that is currently configured and returns a + /// token that is used to replace it later. + /// + /// This is not a safe operation since it's intended to only be used + /// with stack switching found with fibers and async wasmtime. + pub unsafe fn take() -> TlsRestore { + // Our tls pointer must be set at this time, and it must not be + // null. We need to restore the previous pointer since we're + // removing ourselves from the call-stack, and in the process we + // null out our own previous field for safety in case it's + // accidentally used later. + let raw = raw::get(); + assert!(!raw.is_null()); + let prev = (*raw).prev.replace(ptr::null()); + raw::replace(prev); + TlsRestore(raw) + } + + /// Restores a previous tls state back into this thread's TLS. + /// + /// This is unsafe because it's intended to only be used within the + /// context of stack switching within wasmtime. + pub unsafe fn replace(self) { + // We need to configure our previous TLS pointer to whatever is in + // TLS at this time, and then we set the current state to ourselves. + let prev = raw::get(); + assert!((*self.0).prev.get().is_null()); + (*self.0).prev.set(prev); + raw::replace(self.0); + } + } /// Configures thread local state such that for the duration of the /// execution of `closure` any call to `with` will yield `ptr`, unless this /// is recursively called again. - pub fn set(ptr: &CallThreadState<'_>, closure: impl FnOnce() -> R) -> R { - struct Reset<'a, T: Copy>(&'a Cell, T); + pub fn set(state: &CallThreadState<'_>, closure: impl FnOnce() -> R) -> R { + struct Reset<'a, 'b>(&'a CallThreadState<'b>); - impl Drop for Reset<'_, T> { + impl Drop for Reset<'_, '_> { fn drop(&mut self) { - self.0.set(self.1); + raw::replace(self.0.prev.replace(ptr::null())); } } - PTR.with(|p| { - // Note that this extension of the lifetime to `'static` should be - // safe because we only ever access it below with an anonymous - // lifetime, meaning `'static` never leaks out of this module. - let ptr = unsafe { - mem::transmute::<*const CallThreadState<'_>, *const CallThreadState<'static>>(ptr) - }; - let _r = Reset(p, p.replace(ptr)); - closure() - }) + // Note that this extension of the lifetime to `'static` should be + // safe because we only ever access it below with an anonymous + // lifetime, meaning `'static` never leaks out of this module. + let ptr = unsafe { + mem::transmute::<*const CallThreadState<'_>, *const CallThreadState<'static>>(state) + }; + let prev = raw::replace(ptr); + state.prev.set(prev); + let _reset = Reset(state); + closure() } /// Returns the last pointer configured with `set` above. Panics if `set` /// has not been previously called. pub fn with(closure: impl FnOnce(Option<&CallThreadState<'_>>) -> R) -> R { - PTR.with(|ptr| { - let p = ptr.get(); - unsafe { closure(if p.is_null() { None } else { Some(&*p) }) } - }) + let p = raw::get(); + unsafe { closure(if p.is_null() { None } else { Some(&*p) }) } } } diff --git a/crates/wasmtime/src/store.rs b/crates/wasmtime/src/store.rs index b5adbe866a..4850c74460 100644 --- a/crates/wasmtime/src/store.rs +++ b/crates/wasmtime/src/store.rs @@ -717,7 +717,10 @@ impl Store { Poll::Pending => {} } unsafe { - (*suspend).suspend(())?; + let before = wasmtime_runtime::TlsRestore::take(); + let res = (*suspend).suspend(()); + before.replace(); + res?; } } } diff --git a/tests/all/async_functions.rs b/tests/all/async_functions.rs index 962a0f113a..ce084f0dca 100644 --- a/tests/all/async_functions.rs +++ b/tests/all/async_functions.rs @@ -1,4 +1,5 @@ use std::cell::Cell; +use std::cell::RefCell; use std::future::Future; use std::pin::Pin; use std::rc::Rc; @@ -493,3 +494,144 @@ fn async_host_func_with_pooling_stacks() { run_smoke_test(&func); run_smoke_get0_test(&func); } + +fn execute_across_threads(future: F) { + struct UnsafeSend(T); + unsafe impl Send for UnsafeSend {} + + impl Future for UnsafeSend { + type Output = T::Output; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + unsafe { self.map_unchecked_mut(|p| &mut p.0).poll(cx) } + } + } + + let mut future = Pin::from(Box::new(UnsafeSend(future))); + let poll = future + .as_mut() + .poll(&mut Context::from_waker(&dummy_waker())); + assert!(poll.is_pending()); + + std::thread::spawn(move || { + let poll = future + .as_mut() + .poll(&mut Context::from_waker(&dummy_waker())); + assert!(!poll.is_pending()); + }) + .join() + .unwrap(); +} + +#[test] +fn resume_separate_thread() { + // This test will poll the following future on two threads. Simulating a + // trap requires accessing TLS info, so that should be preserved correctly. + execute_across_threads(async { + let store = async_store(); + let module = Module::new( + store.engine(), + " + (module + (import \"\" \"\" (func)) + (start 0) + ) + ", + ) + .unwrap(); + let func = Func::wrap0_async(&store, (), |_, _| { + Box::new(async { + PendingOnce::default().await; + Err::<(), _>(wasmtime::Trap::new("test")) + }) + }); + let result = Instance::new_async(&store, &module, &[func.into()]).await; + assert!(result.is_err()); + }); +} + +#[test] +fn resume_separate_thread2() { + // This test will poll the following future on two threads. Catching a + // signal requires looking up TLS information to determine whether it's a + // trap to handle or not, so that must be preserved correctly across threads. + execute_across_threads(async { + let store = async_store(); + let module = Module::new( + store.engine(), + " + (module + (import \"\" \"\" (func)) + (func $start + call 0 + unreachable) + (start $start) + ) + ", + ) + .unwrap(); + let func = Func::wrap0_async(&store, (), |_, _| { + Box::new(async { PendingOnce::default().await }) + }); + let result = Instance::new_async(&store, &module, &[func.into()]).await; + assert!(result.is_err()); + }); +} + +#[test] +fn resume_separate_thread3() { + // This test doesn't actually do anything with cross-thread polls, but + // instead it deals with scheduling futures at "odd" times. + // + // First we'll set up a *synchronous* call which will initialize TLS info. + // This call is simply to a host-defined function, but it still has the same + // "enter into wasm" semantics since it's just calling a trampoline. In this + // situation we'll set up the TLS info so it's in place while the body of + // the function executes... + let store = Store::default(); + let storage = Rc::new(RefCell::new(None)); + let storage2 = storage.clone(); + let f = Func::wrap(&store, move || { + // ... and the execution of this host-defined function (while the TLS + // info is initialized), will set up a recursive call into wasm. This + // recursive call will be done asynchronously so we can suspend it + // halfway through. + let f = async { + let store = async_store(); + let module = Module::new( + store.engine(), + " + (module + (import \"\" \"\" (func)) + (start 0) + ) + ", + ) + .unwrap(); + let func = Func::wrap0_async(&store, (), |_, _| { + Box::new(async { PendingOnce::default().await }) + }); + drop(Instance::new_async(&store, &module, &[func.into()]).await); + unreachable!() + }; + let mut future = Pin::from(Box::new(f)); + let poll = future + .as_mut() + .poll(&mut Context::from_waker(&dummy_waker())); + assert!(poll.is_pending()); + + // ... so at this point our call into wasm is suspended. The call into + // wasm will have overwritten TLS info, and we sure hope that the + // information is restored at this point. Note that we squirrel away the + // future somewhere else to get dropped later. If we were to drop it + // here then we would reenter the future's suspended stack to clean it + // up, which would do more alterations of TLS information we're not + // testing here. + *storage2.borrow_mut() = Some(future); + + // ... all in all this function will need access to the original TLS + // information to raise the trap. This TLS information should be + // restored even though the asynchronous execution is suspended. + Err::<(), _>(wasmtime::Trap::new("")) + }); + assert!(f.call(&[]).is_err()); +}