Wasmtime: fix stack walking across frames from different stores (#4779)

We were previously implicitly assuming that all Wasm frames in a stack used the
same `VMRuntimeLimits` as the previous frame we walked, but this is not true
when Wasm in store A calls into the host, which then calls into Wasm in store B:

    | ...             |
    | Host            |  |
    +-----------------+  | stack
    | Wasm in store A |  | grows
    +-----------------+  | down
    | Host            |  |
    +-----------------+  |
    | Wasm in store B |  V
    +-----------------+

Trying to walk this stack would previously result in a runtime panic.

The fix is to push maintenance of our list of saved Wasm FP/SP/PC registers
(which lets us identify contiguous regions of Wasm frames on the stack) deeper
into `CallThreadState`. The saved-registers list is now maintained whenever the
`CallThreadState` linked list is updated: the `CallThreadState::prev` field is
private and accessible only via a getter and a setter, and the setter always
upholds our invariants.
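
To make the scenario concrete, here is a minimal sketch using the public
`wasmtime` embedding API (the module text and names are illustrative, not the
actual regression test from this PR): Wasm in store A calls a host import, and
the host re-enters Wasm in a second store, where an `unreachable` forces a
backtrace capture across both regions of Wasm frames:

    use std::sync::Mutex;
    use wasmtime::{Caller, Engine, Func, Instance, Module, Store};

    fn main() -> anyhow::Result<()> {
        let engine = Engine::default();

        // Store B: exports a function that traps, forcing a backtrace capture.
        let mut store_b = Store::new(&engine, ());
        let module_b = Module::new(&engine, r#"(module (func (export "f") unreachable))"#)?;
        let instance_b = Instance::new(&mut store_b, &module_b, &[])?;
        let f_b = instance_b.get_typed_func::<(), ()>(&mut store_b, "f")?;
        let store_b = Mutex::new(store_b);

        // Store A: its Wasm calls a host import, which re-enters Wasm in store B.
        let mut store_a = Store::new(&engine, ());
        let module_a = Module::new(
            &engine,
            r#"(module (import "" "" (func)) (func (export "run") call 0))"#,
        )?;
        let host = Func::wrap(&mut store_a, move |_: Caller<'_, ()>| {
            let mut b = store_b.lock().unwrap();
            // `f` hits `unreachable` here; capturing the trap's backtrace must
            // walk from store B's Wasm frames, across this host frame, into
            // store A's Wasm frames. This walk is what previously panicked.
            let _trap = f_b.call(&mut *b, ()).unwrap_err();
        });
        let instance_a = Instance::new(&mut store_a, &module_a, &[host.into()])?;
        let run = instance_a.get_typed_func::<(), ()>(&mut store_a, "run")?;
        run.call(&mut store_a, ())?; // succeeds: the host swallows store B's trap
        Ok(())
    }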
Nick Fitzgerald, 2022-08-30 11:28:00 -07:00, committed by GitHub
parent 09c93c70cc, commit ff0e84ecf4
7 changed files with 492 additions and 94 deletions


@@ -7,7 +7,7 @@ use crate::{VMContext, VMRuntimeLimits};
use anyhow::Error;
use std::any::Any;
use std::cell::{Cell, UnsafeCell};
use std::mem::{self, MaybeUninit};
use std::mem::MaybeUninit;
use std::ptr;
use std::sync::Once;
use wasmtime_environ::TrapCode;
@@ -182,19 +182,7 @@ where
{
let limits = (*caller).instance().runtime_limits();
let old_last_wasm_exit_fp = mem::replace(&mut *(**limits).last_wasm_exit_fp.get(), 0);
let old_last_wasm_exit_pc = mem::replace(&mut *(**limits).last_wasm_exit_pc.get(), 0);
let old_last_wasm_entry_sp = mem::replace(&mut *(**limits).last_wasm_entry_sp.get(), 0);
let result = CallThreadState::new(
signal_handler,
capture_backtrace,
old_last_wasm_exit_fp,
old_last_wasm_exit_pc,
old_last_wasm_entry_sp,
*limits,
)
.with(|cx| {
let result = CallThreadState::new(signal_handler, capture_backtrace, *limits).with(|cx| {
wasmtime_setjmp(
cx.jmp_buf.as_ptr(),
call_closure::<F>,
@@ -203,10 +191,6 @@ where
)
});
*(**limits).last_wasm_exit_fp.get() = old_last_wasm_exit_fp;
*(**limits).last_wasm_exit_pc.get() = old_last_wasm_exit_pc;
*(**limits).last_wasm_entry_sp.get() = old_last_wasm_entry_sp;
return match result {
Ok(x) => Ok(x),
Err((UnwindReason::Trap(reason), backtrace)) => Err(Box::new(Trap { reason, backtrace })),
@@ -221,20 +205,159 @@ where
}
}
/// Temporary state stored on the stack which is registered in the `tls` module
/// below for calls into wasm.
pub struct CallThreadState {
unwind: UnsafeCell<MaybeUninit<(UnwindReason, Option<Backtrace>)>>,
jmp_buf: Cell<*const u8>,
handling_trap: Cell<bool>,
signal_handler: Option<*const SignalHandler<'static>>,
prev: Cell<tls::Ptr>,
capture_backtrace: bool,
pub(crate) old_last_wasm_exit_fp: usize,
pub(crate) old_last_wasm_exit_pc: usize,
pub(crate) old_last_wasm_entry_sp: usize,
pub(crate) limits: *const VMRuntimeLimits,
// Module to hide visibility of the `CallThreadState::prev` field and force
// usage of its accessor methods.
mod call_thread_state {
use super::*;
use std::mem;
/// Temporary state stored on the stack which is registered in the `tls` module
/// below for calls into wasm.
pub struct CallThreadState {
pub(super) unwind: UnsafeCell<MaybeUninit<(UnwindReason, Option<Backtrace>)>>,
pub(super) jmp_buf: Cell<*const u8>,
pub(super) handling_trap: Cell<bool>,
pub(super) signal_handler: Option<*const SignalHandler<'static>>,
pub(super) capture_backtrace: bool,
pub(crate) limits: *const VMRuntimeLimits,
prev: Cell<tls::Ptr>,
// The values of `VMRuntimeLimits::last_wasm_{exit_{pc,fp},entry_sp}` for
// the *previous* `CallThreadState`. Our *current* last wasm PC/FP/SP are
// saved in `self.limits`. We save a copy of the old registers here because
// the `VMRuntimeLimits` typically doesn't change across nested calls into
// Wasm (i.e. they are typically calls back into the same store and
// `self.limits == self.prev.limits`) and we must maintain the list of
// contiguous-Wasm-frames stack regions for backtracing purposes.
old_last_wasm_exit_fp: Cell<usize>,
old_last_wasm_exit_pc: Cell<usize>,
old_last_wasm_entry_sp: Cell<usize>,
}
impl CallThreadState {
#[inline]
pub(super) fn new(
signal_handler: Option<*const SignalHandler<'static>>,
capture_backtrace: bool,
limits: *const VMRuntimeLimits,
) -> CallThreadState {
CallThreadState {
unwind: UnsafeCell::new(MaybeUninit::uninit()),
jmp_buf: Cell::new(ptr::null()),
handling_trap: Cell::new(false),
signal_handler,
capture_backtrace,
limits,
prev: Cell::new(ptr::null()),
old_last_wasm_exit_fp: Cell::new(0),
old_last_wasm_exit_pc: Cell::new(0),
old_last_wasm_entry_sp: Cell::new(0),
}
}
/// Get the saved FP upon exit from Wasm for the previous `CallThreadState`.
pub fn old_last_wasm_exit_fp(&self) -> usize {
self.old_last_wasm_exit_fp.get()
}
/// Get the saved PC upon exit from Wasm for the previous `CallThreadState`.
pub fn old_last_wasm_exit_pc(&self) -> usize {
self.old_last_wasm_exit_pc.get()
}
/// Get the saved SP upon entry into Wasm for the previous `CallThreadState`.
pub fn old_last_wasm_entry_sp(&self) -> usize {
self.old_last_wasm_entry_sp.get()
}
/// Get the previous `CallThreadState`.
pub fn prev(&self) -> tls::Ptr {
self.prev.get()
}
/// Connect the link to the previous `CallThreadState`.
///
/// Synchronizes the last wasm FP, PC, and SP on `self` and the old
/// `self.prev` for the given new `prev`, and returns the old
/// `self.prev`.
pub unsafe fn set_prev(&self, prev: tls::Ptr) -> tls::Ptr {
let old_prev = self.prev.get();
// Restore the old `prev`'s saved registers in its
// `VMRuntimeLimits`. This is necessary when we are suspending the
// top `CallThreadState` during an async yield and doing
// `set_prev(null)` on it: any stack walking we do subsequently
// will start at the old `prev` and look at its `VMRuntimeLimits`
// to get the initial saved registers.
if let Some(old_prev) = old_prev.as_ref() {
*(*old_prev.limits).last_wasm_exit_fp.get() = self.old_last_wasm_exit_fp();
*(*old_prev.limits).last_wasm_exit_pc.get() = self.old_last_wasm_exit_pc();
*(*old_prev.limits).last_wasm_entry_sp.get() = self.old_last_wasm_entry_sp();
}
self.prev.set(prev);
let mut old_last_wasm_exit_fp = 0;
let mut old_last_wasm_exit_pc = 0;
let mut old_last_wasm_entry_sp = 0;
if let Some(prev) = prev.as_ref() {
// We are entering a new `CallThreadState` or resuming a
// previously suspended one. This means we will push new Wasm
// frames that save the new Wasm FP/SP/PC registers into
// `VMRuntimeLimits`, so we need to first save the old Wasm
// FP/SP/PC registers into this new `CallThreadState` to
// maintain our list of contiguous Wasm frame regions that we
// use when capturing stack traces.
//
// NB: the Wasm<--->host trampolines saved the Wasm FP/SP/PC
// registers in the active-at-that-time store's
// `VMRuntimeLimits`. For the most recent FP/PC/SP that is the
// `state.prev.limits` (since we haven't entered this
// `CallThreadState` yet). And that can be a different
// `VMRuntimeLimits` instance from the currently active
// `state.limits`, which will be used by the upcoming call into
// Wasm! Consider the case where we have multiple, nested calls
// across stores (with host code in between, by necessity, since
// only things in the same store can be linked directly
// together):
//
// | ... |
// | Host | |
// +-----------------+ | stack
// | Wasm in store A | | grows
// +-----------------+ | down
// | Host | |
// +-----------------+ |
// | Wasm in store B | V
// +-----------------+
//
// In this scenario `state.limits != state.prev.limits`,
// i.e. `B.limits != A.limits`! Therefore we must take care to
// read the old FP/SP/PC from `state.prev.limits`, rather than
// `state.limits`, and store those saved registers into the
// current `state`.
//
// See also the comment above the
// `CallThreadState::old_last_wasm_*` fields.
old_last_wasm_exit_fp =
mem::replace(&mut *(*prev.limits).last_wasm_exit_fp.get(), 0);
old_last_wasm_exit_pc =
mem::replace(&mut *(*prev.limits).last_wasm_exit_pc.get(), 0);
old_last_wasm_entry_sp =
mem::replace(&mut *(*prev.limits).last_wasm_entry_sp.get(), 0);
}
self.old_last_wasm_exit_fp.set(old_last_wasm_exit_fp);
self.old_last_wasm_exit_pc.set(old_last_wasm_exit_pc);
self.old_last_wasm_entry_sp.set(old_last_wasm_entry_sp);
old_prev
}
}
}
pub use call_thread_state::*;
enum UnwindReason {
Panic(Box<dyn Any + Send>),
@@ -242,34 +365,11 @@ enum UnwindReason {
}
impl CallThreadState {
#[inline]
fn new(
signal_handler: Option<*const SignalHandler<'static>>,
capture_backtrace: bool,
old_last_wasm_exit_fp: usize,
old_last_wasm_exit_pc: usize,
old_last_wasm_entry_sp: usize,
limits: *const VMRuntimeLimits,
) -> CallThreadState {
CallThreadState {
unwind: UnsafeCell::new(MaybeUninit::uninit()),
jmp_buf: Cell::new(ptr::null()),
handling_trap: Cell::new(false),
signal_handler,
prev: Cell::new(ptr::null()),
capture_backtrace,
old_last_wasm_exit_fp,
old_last_wasm_exit_pc,
old_last_wasm_entry_sp,
limits,
}
}
fn with(
self,
mut self,
closure: impl FnOnce(&CallThreadState) -> i32,
) -> Result<(), (UnwindReason, Option<Backtrace>)> {
let ret = tls::set(&self, || closure(&self));
let ret = tls::set(&mut self, |me| closure(me));
if ret != 0 {
Ok(())
} else {
@@ -366,7 +466,7 @@ impl CallThreadState {
let mut state = Some(self);
std::iter::from_fn(move || {
let this = state?;
state = unsafe { this.prev.get().as_ref() };
state = unsafe { this.prev().as_ref() };
Some(this)
})
}
@@ -462,7 +562,9 @@ mod tls {
/// Opaque state used to help control TLS state across stack switches for
/// async support.
pub struct TlsRestore(raw::Ptr);
pub struct TlsRestore {
state: raw::Ptr,
}
impl TlsRestore {
/// Takes the TLS state that is currently configured and returns a
@@ -476,14 +578,16 @@ mod tls {
// removing ourselves from the call-stack, and in the process we
// null out our own previous field for safety in case it's
// accidentally used later.
let raw = raw::get();
if !raw.is_null() {
let prev = (*raw).prev.replace(ptr::null());
raw::replace(prev);
let state = raw::get();
if let Some(state) = state.as_ref() {
let prev_state = state.set_prev(ptr::null());
raw::replace(prev_state);
} else {
// Null case: we aren't in a wasm context, so there's no tls to
// save for restoration.
}
// Null case: we aren't in a wasm context, so there's no tls
// to save for restoration.
TlsRestore(raw)
TlsRestore { state }
}
/// Restores a previous tls state back into this thread's TLS.
@@ -493,40 +597,50 @@ mod tls {
pub unsafe fn replace(self) {
// Null case: we aren't in a wasm context, so theres no tls
// to restore.
if self.0.is_null() {
if self.state.is_null() {
return;
}
// We need to configure our previous TLS pointer to whatever is in
// TLS at this time, and then we set the current state to ourselves.
let prev = raw::get();
assert!((*self.0).prev.get().is_null());
(*self.0).prev.set(prev);
raw::replace(self.0);
assert!((*self.state).prev().is_null());
(*self.state).set_prev(prev);
raw::replace(self.state);
}
}
/// Configures thread local state such that for the duration of the
/// execution of `closure` any call to `with` will yield `ptr`, unless this
/// is recursively called again.
/// execution of `closure` any call to `with` will yield `state`, unless
/// this is recursively called again.
#[inline]
pub fn set<R>(state: &CallThreadState, closure: impl FnOnce() -> R) -> R {
struct Reset<'a>(&'a CallThreadState);
pub fn set<R>(state: &mut CallThreadState, closure: impl FnOnce(&CallThreadState) -> R) -> R {
struct Reset<'a> {
state: &'a CallThreadState,
}
impl Drop for Reset<'_> {
#[inline]
fn drop(&mut self) {
raw::replace(self.0.prev.replace(ptr::null()));
unsafe {
let prev = self.state.set_prev(ptr::null());
let old_state = raw::replace(prev);
debug_assert!(std::ptr::eq(old_state, self.state));
}
}
}
let prev = raw::replace(state);
state.prev.set(prev);
let _reset = Reset(state);
closure()
unsafe {
state.set_prev(prev);
let reset = Reset { state };
closure(reset.state)
}
}
/// Returns the last pointer configured with `set` above. Panics if `set`
/// has not been previously called.
/// Returns the last pointer configured with `set` above, if any.
pub fn with<R>(closure: impl FnOnce(Option<&CallThreadState>) -> R) -> R {
let p = raw::get();
unsafe { closure(if p.is_null() { None } else { Some(&*p) }) }


@@ -149,18 +149,18 @@ impl Backtrace {
// trace through (since each `CallThreadState` saves the *previous*
// call into Wasm's saved registers, and the youngest call into
// Wasm's registers are saved in the `VMRuntimeLimits`)
if state.prev.get().is_null() {
debug_assert_eq!(state.old_last_wasm_exit_pc, 0);
debug_assert_eq!(state.old_last_wasm_exit_fp, 0);
debug_assert_eq!(state.old_last_wasm_entry_sp, 0);
if state.prev().is_null() {
debug_assert_eq!(state.old_last_wasm_exit_pc(), 0);
debug_assert_eq!(state.old_last_wasm_exit_fp(), 0);
debug_assert_eq!(state.old_last_wasm_entry_sp(), 0);
log::trace!("====== Done Capturing Backtrace ======");
return;
}
if let ControlFlow::Break(()) = Self::trace_through_wasm(
state.old_last_wasm_exit_pc,
state.old_last_wasm_exit_fp,
state.old_last_wasm_entry_sp,
state.old_last_wasm_exit_pc(),
state.old_last_wasm_exit_fp(),
state.old_last_wasm_entry_sp(),
&mut f,
) {
log::trace!("====== Done Capturing Backtrace ======");
@@ -266,7 +266,7 @@ impl Backtrace {
}
/// Iterate over the frames inside this backtrace.
pub fn frames<'a>(&'a self) -> impl Iterator<Item = &'a Frame> + 'a {
pub fn frames<'a>(&'a self) -> impl ExactSizeIterator<Item = &'a Frame> + 'a {
self.0.iter()
}
}
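
Strengthening `frames` to return an `ExactSizeIterator` is backwards
compatible: every caller still gets an `Iterator`, but can now also ask for
the frame count up front, e.g. something like `backtrace.frames().len()` in a
hypothetical consumer, without collecting the frames first.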