diff --git a/crates/api/src/callable.rs b/crates/api/src/callable.rs index 1fbaa3789b..b13fe6f589 100644 --- a/crates/api/src/callable.rs +++ b/crates/api/src/callable.rs @@ -1,5 +1,5 @@ use crate::runtime::Store; -use crate::trampoline::{generate_func_export, take_api_trap}; +use crate::trampoline::generate_func_export; use crate::trap::Trap; use crate::types::FuncType; use crate::values::Val; @@ -157,8 +157,7 @@ impl WrappedCallable for WasmtimeFn { ) }) } { - let trap = take_api_trap().unwrap_or_else(|| Trap::from_jit(error)); - return Err(trap); + return Err(Trap::from_jit(error)); } // Load the return values out of `values_vec`. diff --git a/crates/api/src/instance.rs b/crates/api/src/instance.rs index 292d49b85a..6b04b0d806 100644 --- a/crates/api/src/instance.rs +++ b/crates/api/src/instance.rs @@ -1,7 +1,6 @@ use crate::externals::Extern; use crate::module::Module; use crate::runtime::Store; -use crate::trampoline::take_api_trap; use crate::trap::Trap; use crate::types::{ExportType, ExternType}; use anyhow::{Error, Result}; @@ -29,12 +28,9 @@ fn instantiate( let instance = compiled_module .instantiate(&mut resolver) .map_err(|e| -> Error { - if let Some(trap) = take_api_trap() { - trap.into() - } else if let InstantiationError::StartTrap(trap) = e { - Trap::from_jit(trap).into() - } else { - e.into() + match e { + InstantiationError::StartTrap(trap) => Trap::from_jit(trap).into(), + other => other.into(), } })?; Ok(instance) diff --git a/crates/api/src/trampoline/func.rs b/crates/api/src/trampoline/func.rs index 29726cc582..d6435c8d79 100644 --- a/crates/api/src/trampoline/func.rs +++ b/crates/api/src/trampoline/func.rs @@ -1,11 +1,12 @@ //! Support for a calling of an imported function. 
use super::create_handle::create_handle; -use super::trap::{record_api_trap, TrapSink, API_TRAP_CODE}; -use crate::{Callable, FuncType, Store, Val}; +use super::trap::TrapSink; +use crate::{Callable, FuncType, Store, Trap, Val}; use anyhow::{bail, Result}; use std::cmp; use std::convert::TryFrom; +use std::panic::{self, AssertUnwindSafe}; use std::rc::Rc; use wasmtime_environ::entity::{EntityRef, PrimaryMap}; use wasmtime_environ::ir::types; @@ -69,42 +70,70 @@ unsafe extern "C" fn stub_fn( _caller_vmctx: *mut VMContext, call_id: u32, values_vec: *mut i128, -) -> u32 { - let instance = InstanceHandle::from_vmctx(vmctx); +) { + // Here we are careful to use `catch_unwind` to ensure Rust panics don't + // unwind past us. The primary reason for this is that Rust considers it UB + // to unwind past an `extern "C"` function. Here we are in an `extern "C"` + // function and the cross into wasm was through an `extern "C"` function at + // the base of the stack as well. We'll need to wait for assorted RFCs and + // language features to enable this to be done in a sound and stable fashion + // before avoiding catching the panic here. + // + // Also note that there are intentionally no local variables on this stack + // frame. The reason for that is that some of the "raise" functions we have + // below will trigger a longjmp, which won't run local destructors if we + // have any. To prevent leaks we avoid having any local destructors by + // avoiding local variables. 
+ let result = panic::catch_unwind(AssertUnwindSafe(|| call_stub(vmctx, call_id, values_vec))); - let (args, returns_len) = { - let module = instance.module_ref(); - let signature = &module.signatures[module.functions[FuncIndex::new(call_id as usize)]]; + match result { + Ok(Ok(())) => {} - let mut args = Vec::new(); - for i in 2..signature.params.len() { - args.push(Val::read_value_from( - values_vec.offset(i as isize - 2), - signature.params[i].value_type, - )) - } - (args, signature.returns.len()) - }; + // If a trap was raised (an error returned from the imported function) + // then we smuggle the trap through `Box<dyn Error>` through to the + // call-site, which gets unwrapped in `Trap::from_jit` later on as we + // convert from the internal `Trap` type to our own `Trap` type in this + // crate. + Ok(Err(trap)) => wasmtime_runtime::raise_user_trap(Box::new(trap)), - let mut returns = vec![Val::null(); returns_len]; - let func = &instance - .host_state() - .downcast_ref::() - .expect("state") - .func; + // And finally if the imported function panicked, then we trigger the + // form of unwinding that's safe to jump over wasm code on all + // platforms. + Err(panic) => wasmtime_runtime::resume_panic(panic), + } - match func.call(&args, &mut returns) { - Ok(()) => { - for (i, r#return) in returns.iter_mut().enumerate() { - // TODO check signature.returns[i].value_type ?
- r#return.write_value_to(values_vec.add(i)); + unsafe fn call_stub( + vmctx: *mut VMContext, + call_id: u32, + values_vec: *mut i128, + ) -> Result<(), Trap> { + let instance = InstanceHandle::from_vmctx(vmctx); + + let (args, returns_len) = { + let module = instance.module_ref(); + let signature = &module.signatures[module.functions[FuncIndex::new(call_id as usize)]]; + + let mut args = Vec::new(); + for i in 2..signature.params.len() { + args.push(Val::read_value_from( + values_vec.offset(i as isize - 2), + signature.params[i].value_type, + )) } - 0 - } - Err(trap) => { - record_api_trap(trap); - 1 + (args, signature.returns.len()) + }; + + let mut returns = vec![Val::null(); returns_len]; + let state = &instance + .host_state() + .downcast_ref::() + .expect("state"); + state.func.call(&args, &mut returns)?; + for (i, ret) in returns.iter_mut().enumerate() { + // TODO check signature.returns[i].value_type ? + ret.write_value_to(values_vec.add(i)); } + Ok(()) } } @@ -136,9 +165,6 @@ fn make_trampoline( // Add the `values_vec` parameter. stub_sig.params.push(ir::AbiParam::new(pointer_type)); - // Add error/trap return. - stub_sig.returns.push(ir::AbiParam::new(types::I32)); - // Compute the size of the values vector. The vmctx and caller vmctx are passed separately. 
let value_size = 16; let values_vec_len = ((value_size as usize) @@ -195,13 +221,10 @@ fn make_trampoline( let callee_value = builder .ins() .iconst(pointer_type, stub_fn as *const VMFunctionBody as i64); - let call = builder + builder .ins() .call_indirect(new_sig, callee_value, &callee_args); - let call_result = builder.func.dfg.inst_results(call)[0]; - builder.ins().trapnz(call_result, API_TRAP_CODE); - let mflags = MemFlags::trusted(); let mut results = Vec::new(); for (i, r) in signature.returns.iter().enumerate() { diff --git a/crates/api/src/trampoline/mod.rs b/crates/api/src/trampoline/mod.rs index 8b44415361..ffb224a232 100644 --- a/crates/api/src/trampoline/mod.rs +++ b/crates/api/src/trampoline/mod.rs @@ -16,7 +16,6 @@ use anyhow::Result; use std::rc::Rc; pub use self::global::GlobalState; -pub use self::trap::take_api_trap; pub fn generate_func_export( ft: &FuncType, diff --git a/crates/api/src/trampoline/trap.rs b/crates/api/src/trampoline/trap.rs index be92126ed9..17424e6393 100644 --- a/crates/api/src/trampoline/trap.rs +++ b/crates/api/src/trampoline/trap.rs @@ -1,32 +1,7 @@ -use std::cell::Cell; - -use crate::Trap; use wasmtime_environ::ir::{SourceLoc, TrapCode}; use wasmtime_environ::TrapInformation; use wasmtime_jit::trampoline::binemit; -// Randomly selected user TrapCode magic number 13. -pub const API_TRAP_CODE: TrapCode = TrapCode::User(13); - -thread_local! { - static RECORDED_API_TRAP: Cell> = Cell::new(None); -} - -pub fn record_api_trap(trap: Trap) { - RECORDED_API_TRAP.with(|data| { - let trap = Cell::new(Some(trap)); - data.swap(&trap); - assert!( - trap.take().is_none(), - "Only one API trap per thread can be recorded at a moment!" 
- ); - }); -} - -pub fn take_api_trap() -> Option { - RECORDED_API_TRAP.with(|data| data.take()) -} - pub(crate) struct TrapSink { pub traps: Vec, } diff --git a/crates/api/src/trap.rs b/crates/api/src/trap.rs index 077145acb0..019939fab3 100644 --- a/crates/api/src/trap.rs +++ b/crates/api/src/trap.rs @@ -33,7 +33,24 @@ impl Trap { } pub(crate) fn from_jit(jit: wasmtime_runtime::Trap) -> Self { - Trap::new_with_trace(jit.to_string(), jit.backtrace) + match jit { + wasmtime_runtime::Trap::User(error) => { + // Since we're the only one using the wasmtime internals (in + // theory) we should only see user errors which were originally + // created from our own `Trap` type (see the trampoline module + // with functions). + // + // If this unwrap trips for someone we'll need to tweak the + // return type of this function to probably be `anyhow::Error` + // or something like that. + *error + .downcast() + .expect("only `Trap` user errors are supported") + } + wasmtime_runtime::Trap::Wasm { desc, backtrace } => { + Trap::new_with_trace(desc.to_string(), backtrace) + } + } } fn new_with_trace(message: String, native_trace: Backtrace) -> Self { diff --git a/crates/api/tests/traps.rs b/crates/api/tests/traps.rs index 9c89e1aa51..c5b221ff28 100644 --- a/crates/api/tests/traps.rs +++ b/crates/api/tests/traps.rs @@ -1,4 +1,5 @@ use anyhow::Result; +use std::panic::{self, AssertUnwindSafe}; use std::rc::Rc; use wasmtime::*; @@ -215,3 +216,95 @@ wasm backtrace: ); Ok(()) } + +#[test] +fn trap_start_function_import() -> Result<()> { + struct ReturnTrap; + + impl Callable for ReturnTrap { + fn call(&self, _params: &[Val], _results: &mut [Val]) -> Result<(), Trap> { + Err(Trap::new("user trap")) + } + } + + let store = Store::default(); + let binary = wat::parse_str( + r#" + (module $a + (import "" "" (func $foo)) + (start $foo) + ) + "#, + )?; + + let module = Module::new(&store, &binary)?; + let sig = FuncType::new(Box::new([]), Box::new([])); + let func = Func::new(&store, sig, 
Rc::new(ReturnTrap)); + let err = Instance::new(&module, &[func.into()]).err().unwrap(); + assert_eq!(err.downcast_ref::().unwrap().message(), "user trap"); + Ok(()) +} + +#[test] +fn rust_panic_import() -> Result<()> { + struct Panic; + + impl Callable for Panic { + fn call(&self, _params: &[Val], _results: &mut [Val]) -> Result<(), Trap> { + panic!("this is a panic"); + } + } + + let store = Store::default(); + let binary = wat::parse_str( + r#" + (module $a + (import "" "" (func $foo)) + (func (export "foo") call $foo) + ) + "#, + )?; + + let module = Module::new(&store, &binary)?; + let sig = FuncType::new(Box::new([]), Box::new([])); + let func = Func::new(&store, sig, Rc::new(Panic)); + let instance = Instance::new(&module, &[func.into()])?; + let func = instance.exports()[0].func().unwrap().clone(); + let err = panic::catch_unwind(AssertUnwindSafe(|| { + drop(func.call(&[])); + })) + .unwrap_err(); + assert_eq!(err.downcast_ref::<&'static str>(), Some(&"this is a panic")); + Ok(()) +} + +#[test] +fn rust_panic_start_function() -> Result<()> { + struct Panic; + + impl Callable for Panic { + fn call(&self, _params: &[Val], _results: &mut [Val]) -> Result<(), Trap> { + panic!("this is a panic"); + } + } + + let store = Store::default(); + let binary = wat::parse_str( + r#" + (module $a + (import "" "" (func $foo)) + (start $foo) + ) + "#, + )?; + + let module = Module::new(&store, &binary)?; + let sig = FuncType::new(Box::new([]), Box::new([])); + let func = Func::new(&store, sig, Rc::new(Panic)); + let err = panic::catch_unwind(AssertUnwindSafe(|| { + drop(Instance::new(&module, &[func.into()])); + })) + .unwrap_err(); + assert_eq!(err.downcast_ref::<&'static str>(), Some(&"this is a panic")); + Ok(()) +} diff --git a/crates/c-api/src/lib.rs b/crates/c-api/src/lib.rs index 2ffdf03cc1..1fc5be8de1 100644 --- a/crates/c-api/src/lib.rs +++ b/crates/c-api/src/lib.rs @@ -6,6 +6,7 @@ // TODO complete the C API use std::cell::RefCell; +use std::panic::{self, 
AssertUnwindSafe}; use std::rc::Rc; use std::{mem, ptr, slice}; use wasmtime::{ @@ -488,15 +489,34 @@ pub unsafe extern "C" fn wasm_func_call( let val = &(*args.add(i)); params.push(val.val()); } - match func.call(¶ms) { - Ok(out) => { + + // We're calling arbitrary code here most of the time, and we in general + // want to try to insulate callers against bugs in wasmtime/wasi/etc if we + // can. As a result we catch panics here and transform them to traps to + // allow the caller to have any insulation possible against Rust panics. + let result = panic::catch_unwind(AssertUnwindSafe(|| func.call(¶ms))); + match result { + Ok(Ok(out)) => { for i in 0..func.result_arity() { let val = &mut (*results.add(i)); *val = wasm_val_t::from_val(&out[i]); } ptr::null_mut() } - Err(trap) => { + Ok(Err(trap)) => { + let trap = Box::new(wasm_trap_t { + trap: HostRef::new(trap), + }); + Box::into_raw(trap) + } + Err(panic) => { + let trap = if let Some(msg) = panic.downcast_ref::() { + Trap::new(msg) + } else if let Some(msg) = panic.downcast_ref::<&'static str>() { + Trap::new(*msg) + } else { + Trap::new("rust panic happened") + }; let trap = Box::new(wasm_trap_t { trap: HostRef::new(trap), }); diff --git a/crates/runtime/src/lib.rs b/crates/runtime/src/lib.rs index 4be4c439c0..a8109b86e5 100644 --- a/crates/runtime/src/lib.rs +++ b/crates/runtime/src/lib.rs @@ -44,7 +44,8 @@ pub use crate::jit_int::GdbJitImageRegistration; pub use crate::mmap::Mmap; pub use crate::sig_registry::SignatureRegistry; pub use crate::trap_registry::{get_mut_trap_registry, get_trap_registry, TrapRegistrationGuard}; -pub use crate::traphandlers::{wasmtime_call, wasmtime_call_trampoline, Trap}; +pub use crate::traphandlers::resume_panic; +pub use crate::traphandlers::{raise_user_trap, wasmtime_call, wasmtime_call_trampoline, Trap}; pub use crate::vmcontext::{ VMCallerCheckedAnyfunc, VMContext, VMFunctionBody, VMFunctionImport, VMGlobalDefinition, VMGlobalImport, VMInvokeArgument, VMMemoryDefinition, 
VMMemoryImport, VMSharedSignatureIndex, diff --git a/crates/runtime/src/trap_registry.rs b/crates/runtime/src/trap_registry.rs index 5afcc16837..f1c4373a2b 100644 --- a/crates/runtime/src/trap_registry.rs +++ b/crates/runtime/src/trap_registry.rs @@ -1,5 +1,6 @@ use lazy_static::lazy_static; use std::collections::HashMap; +use std::fmt; use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use wasmtime_environ::ir; @@ -22,6 +23,35 @@ pub struct TrapDescription { pub trap_code: ir::TrapCode, } +impl fmt::Display for TrapDescription { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "wasm trap: {}, source location: {}", + trap_code_to_expected_string(self.trap_code), + self.source_loc + ) + } +} + +fn trap_code_to_expected_string(trap_code: ir::TrapCode) -> String { + use ir::TrapCode::*; + match trap_code { + StackOverflow => "call stack exhausted".to_string(), + HeapOutOfBounds => "out of bounds memory access".to_string(), + TableOutOfBounds => "undefined element".to_string(), + OutOfBounds => "out of bounds".to_string(), // Note: not covered by the test suite + IndirectCallToNull => "uninitialized element".to_string(), + BadSignature => "indirect call type mismatch".to_string(), + IntegerOverflow => "integer overflow".to_string(), + IntegerDivisionByZero => "integer divide by zero".to_string(), + BadConversionToInteger => "invalid conversion to integer".to_string(), + UnreachableCodeReached => "unreachable".to_string(), + Interrupt => "interrupt".to_string(), // Note: not covered by the test suite + User(x) => format!("user trap {}", x), // Note: not covered by the test suite + } +} + /// RAII guard for deregistering traps pub struct TrapRegistrationGuard(usize); diff --git a/crates/runtime/src/traphandlers.rs b/crates/runtime/src/traphandlers.rs index bc05e1cbf1..2c532854c9 100644 --- a/crates/runtime/src/traphandlers.rs +++ b/crates/runtime/src/traphandlers.rs @@ -5,7 +5,9 @@ use crate::trap_registry::get_trap_registry; use 
crate::trap_registry::TrapDescription; use crate::vmcontext::{VMContext, VMFunctionBody}; use backtrace::Backtrace; +use std::any::Any; use std::cell::Cell; +use std::error::Error; use std::fmt; use std::ptr; use wasmtime_environ::ir; @@ -24,6 +26,7 @@ extern "C" { caller_vmctx: *mut u8, callee: *const VMFunctionBody, ) -> i32; + fn Unwind(jmp_buf: *const u8) -> !; } /// Record the Trap code and wasm bytecode offset in TLS somewhere @@ -44,7 +47,7 @@ pub extern "C" fn RecordTrap(pc: *const u8, reset_guard_page: bool) -> *const u8 } let registry = get_trap_registry(); - let trap = Trap { + let trap = Trap::Wasm { desc: registry .get_trap(pc as usize) .unwrap_or_else(|| TrapDescription { @@ -58,16 +61,38 @@ pub extern "C" fn RecordTrap(pc: *const u8, reset_guard_page: bool) -> *const u8 info.reset_guard_page.set(true); } - let prev = info.trap.replace(Some(trap)); - assert!( - prev.is_none(), - "Only one trap per thread can be recorded at a moment!" - ); - + info.unwind.replace(UnwindReason::Trap(trap)); info.jmp_buf.get() }) } +/// Raises a user-defined trap immediately. +/// +/// This function performs as-if a wasm trap was just executed, only the trap +/// has a dynamic payload associated with it which is user-provided. This trap +/// payload is then returned from `wasmtime_call` and `wasmtime_call_trampoline` +/// below. +/// +/// # Safety +/// +/// Only safe to call when wasm code is on the stack, aka `wasmtime_call` or +/// `wasmtime_call_trampoline` must have been previously called. +pub unsafe fn raise_user_trap(data: Box) -> ! { + let trap = Trap::User(data); + tls::with(|info| info.unwind_with(UnwindReason::Trap(trap))) +} + +/// Carries a Rust panic across wasm code and resumes the panic on the other +/// side. +/// +/// # Safety +/// +/// Only safe to call when wasm code is on the stack, aka `wasmtime_call` or +/// `wasmtime_call_trampoline` must have been previously called. +pub unsafe fn resume_panic(payload: Box) -> !
{ + tls::with(|info| info.unwind_with(UnwindReason::Panic(payload))) +} + #[cfg(target_os = "windows")] fn reset_guard_page() { extern "C" { @@ -86,45 +111,30 @@ fn reset_guard_page() {} /// Stores trace message with backtrace. #[derive(Debug)] -pub struct Trap { - /// What sort of trap happened, as well as where in the original wasm module - /// it happened. - pub desc: TrapDescription, - /// Native stack backtrace at the time the trap occurred - pub backtrace: Backtrace, +pub enum Trap { + /// A user-raised trap through `raise_user_trap`. + User(Box), + /// A wasm-originating trap from wasm code itself. + Wasm { + /// What sort of trap happened, as well as where in the original wasm module + /// it happened. + desc: TrapDescription, + /// Native stack backtrace at the time the trap occurred + backtrace: Backtrace, + }, } impl fmt::Display for Trap { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "wasm trap: {}, source location: {}", - trap_code_to_expected_string(self.desc.trap_code), - self.desc.source_loc - ) + match self { + Trap::User(user) => user.fmt(f), + Trap::Wasm { desc, .. 
} => desc.fmt(f), + } } } impl std::error::Error for Trap {} -fn trap_code_to_expected_string(trap_code: ir::TrapCode) -> String { - use ir::TrapCode::*; - match trap_code { - StackOverflow => "call stack exhausted".to_string(), - HeapOutOfBounds => "out of bounds memory access".to_string(), - TableOutOfBounds => "undefined element".to_string(), - OutOfBounds => "out of bounds".to_string(), // Note: not covered by the test suite - IndirectCallToNull => "uninitialized element".to_string(), - BadSignature => "indirect call type mismatch".to_string(), - IntegerOverflow => "integer overflow".to_string(), - IntegerDivisionByZero => "integer divide by zero".to_string(), - BadConversionToInteger => "invalid conversion to integer".to_string(), - UnreachableCodeReached => "unreachable".to_string(), - Interrupt => "interrupt".to_string(), // Note: not covered by the test suite - User(x) => format!("user trap {}", x), // Note: not covered by the test suite - } -} - /// Call the wasm function pointed to by `callee`. `values_vec` points to /// a buffer which holds the incoming arguments, and to which the outgoing /// return values will be written. @@ -145,12 +155,7 @@ pub unsafe extern "C" fn wasmtime_call_trampoline( values_vec, ) }); - - if ret == 0 { - Err(cx.unwrap_trap()) - } else { - Ok(()) - } + cx.into_result(ret) } /// Call the wasm function pointed to by `callee`, which has no arguments or @@ -170,34 +175,54 @@ pub unsafe extern "C" fn wasmtime_call( callee, ) }); - if ret == 0 { - Err(cx.unwrap_trap()) - } else { - Ok(()) - } + cx.into_result(ret) } /// Temporary state stored on the stack which is registered in the `tls` module /// below for calls into wasm. 
pub struct CallThreadState { - trap: Cell>, + unwind: Cell, jmp_buf: Cell<*const u8>, reset_guard_page: Cell, } +enum UnwindReason { + None, + Panic(Box), + Trap(Trap), +} + impl CallThreadState { fn new() -> CallThreadState { CallThreadState { - trap: Cell::new(None), + unwind: Cell::new(UnwindReason::None), jmp_buf: Cell::new(ptr::null()), reset_guard_page: Cell::new(false), } } - fn unwrap_trap(self) -> Trap { - self.trap - .replace(None) - .expect("unwrap_trap must be called after trap occurred") + fn into_result(self, ret: i32) -> Result<(), Trap> { + match self.unwind.replace(UnwindReason::None) { + UnwindReason::None => { + debug_assert_eq!(ret, 1); + Ok(()) + } + UnwindReason::Trap(trap) => { + debug_assert_eq!(ret, 0); + Err(trap) + } + UnwindReason::Panic(panic) => { + debug_assert_eq!(ret, 0); + std::panic::resume_unwind(panic) + } + } + } + + fn unwind_with(&self, reason: UnwindReason) -> ! { + self.unwind.replace(reason); + unsafe { + Unwind(self.jmp_buf.get()); + } } }