From 50ce19a4a4218bbdb6a5c676e0f77014cef68fd4 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Fri, 3 Sep 2021 13:40:51 -0500 Subject: [PATCH] Remove an indirect function call in `Func::new` (#3293) This commit optimizes the runtime execution of `Func::new` by removing an indirect function call that happens whenever a host function is called. This indirection was generally done to prevent monomoprhizing a lot into consumer code but the few extra functions this makes monomorphic are fairly small, and in general wasm->host call performance is pretty important. While not a massive win this is expected to improve codegen, especially because with the indirect call removed the compiler should now be able to prove more often when a `Func::new` closure doesn't panic or return an error. --- crates/wasmtime/src/func.rs | 4 +-- crates/wasmtime/src/trampoline/func.rs | 42 ++++++++++++-------------- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/crates/wasmtime/src/func.rs b/crates/wasmtime/src/func.rs index c66a43aa6e..08ce16efd1 100644 --- a/crates/wasmtime/src/func.rs +++ b/crates/wasmtime/src/func.rs @@ -1894,11 +1894,11 @@ impl HostFunc { let ty_clone = ty.clone(); // Create a trampoline that converts raw u128 values to `Val` - let func = Box::new(move |caller_vmctx, values_vec: *mut u128| unsafe { + let func = move |caller_vmctx, values_vec: *mut u128| unsafe { Caller::with(caller_vmctx, |caller| { Func::invoke(caller, &ty_clone, values_vec, &func) }) - }); + }; let (instance, trampoline) = crate::trampoline::create_function(&ty, func, engine) .expect("failed to create function"); diff --git a/crates/wasmtime/src/trampoline/func.rs b/crates/wasmtime/src/trampoline/func.rs index 7f4221357d..e215dfe8f2 100644 --- a/crates/wasmtime/src/trampoline/func.rs +++ b/crates/wasmtime/src/trampoline/func.rs @@ -12,17 +12,19 @@ use wasmtime_runtime::{ OnDemandInstanceAllocator, VMContext, VMFunctionBody, VMSharedSignatureIndex, VMTrampoline, }; -struct TrampolineState { - func: Box Result<(), Trap> + Send + Sync>, +struct TrampolineState { + func: F, #[allow(dead_code)] code_memory: CodeMemory, } -unsafe extern "C" fn stub_fn( +unsafe extern "C" fn stub_fn( vmctx: *mut VMContext, caller_vmctx: *mut VMContext, values_vec: *mut u128, -) { +) where + F: Fn(*mut VMContext, *mut u128) -> Result<(), Trap> + 'static, +{ // Here we are careful to use `catch_unwind` to ensure Rust panics don't // unwind past us. The primary reason for this is that Rust considers it UB // to unwind past an `extern "C"` function. Here we are in an `extern "C"` @@ -37,7 +39,13 @@ unsafe extern "C" fn stub_fn( // have any. To prevent leaks we avoid having any local destructors by // avoiding local variables. let result = panic::catch_unwind(AssertUnwindSafe(|| { - call_stub(vmctx, caller_vmctx, values_vec) + // Double-check ourselves in debug mode, but we control + // the `Any` here so an unsafe downcast should also + // work. + let state = (*vmctx).host_state(); + debug_assert!(state.is::>()); + let state = &*(state as *const _ as *const TrampolineState); + (state.func)(caller_vmctx, values_vec) })); match result { @@ -55,31 +63,21 @@ unsafe extern "C" fn stub_fn( // platforms. Err(panic) => wasmtime_runtime::resume_panic(panic), } - - unsafe fn call_stub( - vmctx: *mut VMContext, - caller_vmctx: *mut VMContext, - values_vec: *mut u128, - ) -> Result<(), Trap> { - let instance = InstanceHandle::from_vmctx(vmctx); - let state = &instance - .host_state() - .downcast_ref::() - .expect("state"); - (state.func)(caller_vmctx, values_vec) - } } #[cfg(compiler)] -pub fn create_function( +pub fn create_function( ft: &FuncType, - func: Box Result<(), Trap> + Send + Sync>, + func: F, engine: &Engine, -) -> Result<(InstanceHandle, VMTrampoline)> { +) -> Result<(InstanceHandle, VMTrampoline)> +where + F: Fn(*mut VMContext, *mut u128) -> Result<(), Trap> + Send + Sync + 'static, +{ let mut obj = engine.compiler().object()?; let (t1, t2) = engine.compiler().emit_trampoline_obj( ft.as_wasm_func_type(), - stub_fn as usize, + stub_fn:: as usize, &mut obj, )?; let obj = MmapVec::from_obj(obj)?;