Implement canon lower of a canon lift function in the same component (#4347)

* Implement `canon lower` of a `canon lift` function in the same component

This commit implements the "degenerate" logic for implementing a
function within a component that is lifted and then immediately lowered
again. In this situation the lowered function will immediately generate
a trap and doesn't need to implement anything else.

The implementation in this commit is somewhat heavyweight, but I think it is
probably justified more by future additions to the component model
than by what exactly is here right now. It's not expected that this
"always trap" functionality will really be used all that often since it
would generally mean a buggy component, but the functionality plumbed
through here is hopefully going to be useful for implementing
component-to-component adapter trampolines.

Specifically this commit implements a strategy where the `canon.lower`'d
function is generated by Cranelift and simply has a single trap
instruction when called, doing nothing else. The main complexity comes
from juggling around all the data associated with these functions,
primarily plumbing through the traps into the `ModuleRegistry` to
ensure that the global `is_wasm_trap_pc` function returns `true` and at
runtime when we look up information about the trap it's all readily
available (e.g. translating the trapping pc to a `TrapCode`).

* Fix non-component build

* Fix some offset calculations

* Only create one "always trap" per signature

Use an internal map to deduplicate during compilation.
This commit is contained in:
Alex Crichton
2022-06-29 11:35:37 -05:00
committed by GitHub
parent 22fb3ecbbf
commit f0278c5db7
16 changed files with 664 additions and 164 deletions

View File

@@ -1,6 +1,7 @@
use crate::signatures::SignatureCollection;
use crate::{Engine, Module};
use anyhow::{bail, Context, Result};
use std::any::Any;
use std::collections::HashMap;
use std::collections::HashSet;
use std::fs;
@@ -9,9 +10,10 @@ use std::path::Path;
use std::ptr::NonNull;
use std::sync::Arc;
use wasmtime_environ::component::{
ComponentTypes, GlobalInitializer, LoweredIndex, LoweringInfo, StaticModuleIndex, Translator,
AlwaysTrapInfo, ComponentTypes, GlobalInitializer, LoweredIndex, LoweringInfo,
RuntimeAlwaysTrapIndex, StaticModuleIndex, Translator,
};
use wasmtime_environ::PrimaryMap;
use wasmtime_environ::{PrimaryMap, SignatureIndex, Trampoline, TrapCode};
use wasmtime_jit::CodeMemory;
use wasmtime_runtime::VMFunctionBody;
@@ -41,19 +43,27 @@ struct ComponentInner {
/// this field.
types: Arc<ComponentTypes>,
/// The in-memory ELF image of the compiled trampolines for this component.
///
/// This is currently only used for wasm-to-host trampolines when
/// `canon.lower` is encountered.
/// The in-memory ELF image of the compiled functions for this component.
trampoline_obj: CodeMemory,
/// The index ranges within `trampoline_obj`'s mmap memory for the entire
/// text section.
text: Range<usize>,
/// Where trampolines are located within the `text` section of
/// `trampoline_obj`.
trampolines: PrimaryMap<LoweredIndex, LoweringInfo>,
/// Where lowered function trampolines are located within the `text`
/// section of `trampoline_obj`.
///
/// These trampolines are the function pointer within the
/// `VMCallerCheckedAnyfunc` and will delegate indirectly to a host function
/// pointer when called.
lowerings: PrimaryMap<LoweredIndex, LoweringInfo>,
/// Where the "always trap" functions are located within the `text` section
/// of `trampoline_obj`.
///
/// These functions are "degenerate functions" here solely to implement
/// functions that are `canon lift`'d then immediately `canon lower`'d.
always_trap: PrimaryMap<RuntimeAlwaysTrapIndex, AlwaysTrapInfo>,
}
impl Component {
@@ -117,39 +127,10 @@ impl Component {
.context("failed to parse WebAssembly module")?;
let types = Arc::new(types.finish());
// All lowered functions will require a trampoline to be available in
// case they're used when entering wasm. For example a lowered function
// could be immediately lifted in which case we'll need a trampoline to
// call that lowered function.
//
// Most of the time trampolines can come from the core wasm modules
// since lifted functions come from core wasm. For these esoteric cases
// though we may have to compile trampolines specifically into the
// component object as well in case core wasm doesn't provide the
// necessary trampoline.
let lowerings = component
.initializers
.iter()
.filter_map(|init| match init {
GlobalInitializer::LowerImport(i) => Some(i),
_ => None,
})
.collect::<Vec<_>>();
let required_trampolines = lowerings
.iter()
.map(|l| l.canonical_abi)
.collect::<HashSet<_>>();
let provided_trampolines = modules
.iter()
.flat_map(|(_, m)| m.exported_signatures.iter().copied())
.collect::<HashSet<_>>();
let mut trampolines_to_compile = required_trampolines
.difference(&provided_trampolines)
.collect::<Vec<_>>();
// Ensure a deterministically compiled artifact by sorting this list
// which was otherwise created with nondeterministically ordered hash
// tables.
trampolines_to_compile.sort();
let (static_modules, trampolines) = engine.join_maybe_parallel(
// In one (possibly) parallel task all the modules found within this
@@ -173,41 +154,10 @@ impl Component {
},
// In another (possibly) parallel task we compile the lowering
// trampolines found to be necessary in the component.
|| -> Result<_> {
let compiler = engine.compiler();
let (lowered_trampolines, core_trampolines) = engine.join_maybe_parallel(
// Compile all the lowered trampolines here which implement
// `canon lower` and are used to exit wasm into the host.
|| -> Result<_> {
Ok(engine
.run_maybe_parallel(lowerings, |lowering| {
compiler
.component_compiler()
.compile_lowered_trampoline(&component, lowering, &types)
})?
.into_iter()
.collect())
},
// Compile all entry host-to-wasm trampolines here that
// aren't otherwise provided by core wasm modules.
|| -> Result<_> {
engine.run_maybe_parallel(trampolines_to_compile.clone(), |i| {
let ty = &types[*i];
Ok((*i, compiler.compile_host_to_wasm_trampoline(ty)?))
})
},
);
let mut obj = engine.compiler().object()?;
let trampolines = compiler.component_compiler().emit_obj(
lowered_trampolines?,
core_trampolines?,
&mut obj,
)?;
Ok((trampolines, wasmtime_jit::mmap_vec_from_obj(obj)?))
},
|| Component::compile_component(engine, &component, &types, &provided_trampolines),
);
let static_modules = static_modules?;
let ((lowering_trampolines, core_trampolines), trampoline_obj) = trampolines?;
let (lowerings, always_trap, trampolines, trampoline_obj) = trampolines?;
let mut trampoline_obj = CodeMemory::new(trampoline_obj);
let code = trampoline_obj.publish()?;
let text = wasmtime_jit::subslice_range(code.text, code.mmap);
@@ -231,8 +181,8 @@ impl Component {
vmtrampolines.insert(idx, trampoline);
}
}
for (signature, trampoline) in trampolines_to_compile.iter().zip(core_trampolines) {
vmtrampolines.insert(**signature, unsafe {
for trampoline in trampolines {
vmtrampolines.insert(trampoline.signature, unsafe {
let ptr =
code.text[trampoline.start as usize..][..trampoline.length as usize].as_ptr();
std::mem::transmute::<*const u8, wasmtime_runtime::VMTrampoline>(ptr)
@@ -248,6 +198,15 @@ impl Component {
vmtrampolines.into_iter(),
);
// Assert that this `always_trap` list is sorted which is relied on in
// `register_component` as well as `Component::lookup_trap_code` below.
assert!(always_trap
.values()
.as_slice()
.windows(2)
.all(|window| { window[0].start < window[1].start }));
crate::module::register_component(code.text, &always_trap);
Ok(Component {
inner: Arc::new(ComponentInner {
component,
@@ -256,11 +215,136 @@ impl Component {
signatures,
trampoline_obj,
text,
trampolines: lowering_trampolines,
lowerings,
always_trap,
}),
})
}
/// Compiles all functions that are emitted into the component's own
/// object (as opposed to the objects of its core wasm modules) and
/// returns the location information for each of them together with the
/// finished object image.
///
/// Three kinds of functions are compiled, possibly in parallel:
///
/// * lowered-import trampolines (`GlobalInitializer::LowerImport`),
/// * "always trap" functions (`GlobalInitializer::AlwaysTrap`),
/// * host-to-wasm trampolines for signatures that are not already in
///   `provided_trampolines`.
///
/// # Errors
///
/// Returns an error if any compilation step, object creation, object
/// emission, or the conversion of the object into an `MmapVec` fails.
#[cfg(compiler)]
fn compile_component(
    engine: &Engine,
    component: &wasmtime_environ::component::Component,
    types: &ComponentTypes,
    provided_trampolines: &HashSet<SignatureIndex>,
) -> Result<(
    PrimaryMap<LoweredIndex, LoweringInfo>,
    PrimaryMap<RuntimeAlwaysTrapIndex, AlwaysTrapInfo>,
    Vec<Trampoline>,
    wasmtime_runtime::MmapVec,
)> {
    // The three compilation jobs are independent of one another, so nest
    // two joins to let all of them run concurrently when the engine allows
    // it. Each job yields its own `Result`, inspected below.
    let (lowerings, (always_trap, trampolines)) = engine.join_maybe_parallel(
        || compile_lowerings(engine, component, types),
        || {
            engine.join_maybe_parallel(
                || compile_always_trap(engine, component, types),
                || compile_trampolines(engine, component, types, provided_trampolines),
            )
        },
    );

    let mut obj = engine.compiler().object()?;
    let (lower, traps, trampolines) = engine.compiler().component_compiler().emit_obj(
        lowerings?,
        always_trap?,
        trampolines?,
        &mut obj,
    )?;
    // Explicit `return` because the helper items below follow it in this
    // block.
    return Ok((
        lower,
        traps,
        trampolines,
        wasmtime_jit::mmap_vec_from_obj(obj)?,
    ));

    /// Compiles one trampoline per `LowerImport` initializer of
    /// `component`.
    fn compile_lowerings(
        engine: &Engine,
        component: &wasmtime_environ::component::Component,
        types: &ComponentTypes,
    ) -> Result<PrimaryMap<LoweredIndex, Box<dyn Any + Send>>> {
        let lowerings = component
            .initializers
            .iter()
            .filter_map(|init| match init {
                GlobalInitializer::LowerImport(i) => Some(i),
                _ => None,
            })
            .collect::<Vec<_>>();
        let compiled = engine.run_maybe_parallel(lowerings, |lowering| {
            engine
                .compiler()
                .component_compiler()
                .compile_lowered_trampoline(component, lowering, types)
        })?;
        Ok(compiled.into_iter().collect())
    }

    /// Compiles one "always trap" function per `AlwaysTrap` initializer of
    /// `component`, using the signature named by its `canonical_abi`.
    fn compile_always_trap(
        engine: &Engine,
        component: &wasmtime_environ::component::Component,
        types: &ComponentTypes,
    ) -> Result<PrimaryMap<RuntimeAlwaysTrapIndex, Box<dyn Any + Send>>> {
        let always_trap = component
            .initializers
            .iter()
            .filter_map(|init| match init {
                GlobalInitializer::AlwaysTrap(i) => Some(i),
                _ => None,
            })
            .collect::<Vec<_>>();
        let compiled = engine.run_maybe_parallel(always_trap, |info| {
            engine
                .compiler()
                .component_compiler()
                .compile_always_trap(&types[info.canonical_abi])
        })?;
        Ok(compiled.into_iter().collect())
    }

    /// Compiles host-to-wasm trampolines for every signature required by a
    /// lowering or "always trap" function that is not already present in
    /// `provided_trampolines`.
    fn compile_trampolines(
        engine: &Engine,
        component: &wasmtime_environ::component::Component,
        types: &ComponentTypes,
        provided_trampolines: &HashSet<SignatureIndex>,
    ) -> Result<Vec<(SignatureIndex, Box<dyn Any + Send>)>> {
        // All lowered functions will require a trampoline to be available in
        // case they're used when entering wasm. For example a lowered function
        // could be immediately lifted in which case we'll need a trampoline to
        // call that lowered function.
        //
        // Most of the time trampolines can come from the core wasm modules
        // since lifted functions come from core wasm. For these esoteric cases
        // though we may have to compile trampolines specifically into the
        // component object as well in case core wasm doesn't provide the
        // necessary trampoline.
        let required_trampolines = component
            .initializers
            .iter()
            .filter_map(|init| match init {
                GlobalInitializer::LowerImport(i) => Some(i.canonical_abi),
                GlobalInitializer::AlwaysTrap(i) => Some(i.canonical_abi),
                _ => None,
            })
            .collect::<HashSet<_>>();
        let mut trampolines_to_compile = required_trampolines
            .difference(provided_trampolines)
            .collect::<Vec<_>>();
        // Ensure a deterministically compiled artifact by sorting this list
        // which was otherwise created with nondeterministically ordered hash
        // tables.
        trampolines_to_compile.sort();
        // The list is not needed afterwards, so hand it over by value
        // instead of cloning it.
        engine.run_maybe_parallel(trampolines_to_compile, |i| {
            let ty = &types[*i];
            Ok((*i, engine.compiler().compile_host_to_wasm_trampoline(ty)?))
        })
    }
}
pub(crate) fn env_component(&self) -> &wasmtime_environ::component::Component {
&self.inner.component
}
@@ -278,13 +362,52 @@ impl Component {
}
pub(crate) fn text(&self) -> &[u8] {
&self.inner.trampoline_obj.mmap()[self.inner.text.clone()]
self.inner.text()
}
pub(crate) fn trampoline_ptr(&self, index: LoweredIndex) -> NonNull<VMFunctionBody> {
let info = &self.inner.trampolines[index];
pub(crate) fn lowering_ptr(&self, index: LoweredIndex) -> NonNull<VMFunctionBody> {
let info = &self.inner.lowerings[index];
self.func(info.start, info.length)
}
/// Returns a pointer to the compiled "always trap" function identified by
/// `index`, located via its recorded start and length.
pub(crate) fn always_trap_ptr(&self, index: RuntimeAlwaysTrapIndex) -> NonNull<VMFunctionBody> {
    let entry = &self.inner.always_trap[index];
    let (start, length) = (entry.start, entry.length);
    self.func(start, length)
}
fn func(&self, start: u32, len: u32) -> NonNull<VMFunctionBody> {
let text = self.text();
let trampoline = &text[info.start as usize..][..info.length as usize];
let trampoline = &text[start as usize..][..len as usize];
NonNull::new(trampoline.as_ptr() as *mut VMFunctionBody).unwrap()
}
/// Resolves the trap code for the instruction at `offset`, where `offset`
/// is measured from the start of this component's text section.
///
/// Returns `None` when `offset` does not fit in a `u32` or does not
/// coincide with a recorded trap location.
pub(crate) fn lookup_trap_code(&self, offset: usize) -> Option<TrapCode> {
    let target = u32::try_from(offset).ok()?;
    // Currently traps only come from "always trap" adapters, so theirs is
    // the only table consulted. Each entry's trapping instruction sits at
    // `start + trap_offset`.
    let entries = self.inner.always_trap.values().as_slice();
    entries
        .binary_search_by_key(&target, |info| info.start + info.trap_offset)
        .ok()
        .map(|_| TrapCode::AlwaysTrapAdapter)
}
}
impl ComponentInner {
    /// Returns the bytes of `trampoline_obj`'s mmap that fall within the
    /// `text` range of this component.
    fn text(&self) -> &[u8] {
        let image = self.trampoline_obj.mmap();
        &image[self.text.clone()]
    }
}
impl Drop for ComponentInner {
    /// Removes this component's text section from the global registry as
    /// the component is destroyed.
    fn drop(&mut self) {
        let text = self.text();
        crate::module::unregister_component(text);
    }
}

View File

@@ -7,9 +7,9 @@ use anyhow::{anyhow, Context, Result};
use std::marker;
use std::sync::Arc;
use wasmtime_environ::component::{
ComponentTypes, CoreDef, CoreExport, Export, ExportItem, ExtractMemory, ExtractPostReturn,
ExtractRealloc, GlobalInitializer, InstantiateModule, LowerImport, RuntimeImportIndex,
RuntimeInstanceIndex, RuntimeModuleIndex,
AlwaysTrap, ComponentTypes, CoreDef, CoreExport, Export, ExportItem, ExtractMemory,
ExtractPostReturn, ExtractRealloc, GlobalInitializer, InstantiateModule, LowerImport,
RuntimeImportIndex, RuntimeInstanceIndex, RuntimeModuleIndex,
};
use wasmtime_environ::{EntityIndex, PrimaryMap};
use wasmtime_runtime::component::{ComponentInstance, OwnedComponentInstance};
@@ -145,12 +145,17 @@ impl InstanceData {
pub fn lookup_def(&self, store: &mut StoreOpaque, def: &CoreDef) -> wasmtime_runtime::Export {
match def {
CoreDef::Export(e) => self.lookup_export(store, e),
CoreDef::Lowered(idx) => {
wasmtime_runtime::Export::Function(wasmtime_runtime::ExportFunction {
anyfunc: self.state.lowering_anyfunc(*idx),
})
}
CoreDef::Export(e) => self.lookup_export(store, e),
CoreDef::AlwaysTrap(idx) => {
wasmtime_runtime::Export::Function(wasmtime_runtime::ExportFunction {
anyfunc: self.state.always_trap_anyfunc(*idx),
})
}
}
}
@@ -272,6 +277,8 @@ impl<'a> Instantiator<'a> {
GlobalInitializer::LowerImport(import) => self.lower_import(import),
GlobalInitializer::AlwaysTrap(trap) => self.always_trap(trap),
GlobalInitializer::ExtractMemory(mem) => self.extract_memory(store.0, mem),
GlobalInitializer::ExtractRealloc(realloc) => {
@@ -307,7 +314,7 @@ impl<'a> Instantiator<'a> {
self.data.state.set_lowering(
import.index,
func.lowering(),
self.component.trampoline_ptr(import.index),
self.component.lowering_ptr(import.index),
self.component
.signatures()
.shared_signature(import.canonical_abi)
@@ -324,6 +331,17 @@ impl<'a> Instantiator<'a> {
self.data.funcs.push(func.clone());
}
/// Handles an `AlwaysTrap` initializer by recording, in the instance
/// state, the function pointer and shared signature of the corresponding
/// "always trap" function.
///
/// Panics if the signature named by `trap.canonical_abi` has not been
/// registered with the component's signature collection.
fn always_trap(&mut self, trap: &AlwaysTrap) {
    let func_ptr = self.component.always_trap_ptr(trap.index);
    let signature = self
        .component
        .signatures()
        .shared_signature(trap.canonical_abi)
        .expect("found unregistered signature");
    self.data
        .state
        .set_always_trap(trap.index, func_ptr, signature);
}
fn extract_memory(&mut self, store: &mut StoreOpaque, memory: &ExtractMemory) {
let mem = match self.data.lookup_export(store, &memory.export) {
wasmtime_runtime::Export::Memory(m) => m,

View File

@@ -26,6 +26,8 @@ mod registry;
mod serialization;
pub use registry::{is_wasm_trap_pc, ModuleRegistry};
#[cfg(feature = "component-model")]
pub use registry::{register_component, unregister_component};
pub use serialization::SerializedModule;
/// A compiled WebAssembly module, ready to be instantiated.
@@ -537,7 +539,7 @@ impl Module {
// into the global registry of modules so we can resolve traps
// appropriately. Note that the corresponding `unregister` happens below
// in `Drop for ModuleInner`.
registry::register(&module);
registry::register_module(&module);
Ok(Self {
inner: Arc::new(ModuleInner {
@@ -987,7 +989,7 @@ impl wasmtime_runtime::ModuleInfo for ModuleInner {
impl Drop for ModuleInner {
fn drop(&mut self) {
registry::unregister(&self.module);
registry::unregister_module(&self.module);
}
}

View File

@@ -8,6 +8,11 @@ use std::{
sync::{Arc, RwLock},
};
use wasmtime_environ::TrapCode;
#[cfg(feature = "component-model")]
use wasmtime_environ::{
component::{AlwaysTrapInfo, RuntimeAlwaysTrapIndex},
PrimaryMap,
};
use wasmtime_jit::CompiledModule;
use wasmtime_runtime::{ModuleInfo, VMCallerCheckedAnyfunc, VMTrampoline};
@@ -15,7 +20,7 @@ use wasmtime_runtime::{ModuleInfo, VMCallerCheckedAnyfunc, VMTrampoline};
///
/// Note that the primary reason for this registry is to ensure that everything
/// in `Module` is kept alive for the duration of a `Store`. At this time we
/// need "basically everything" within a `Moudle` to stay alive once it's
/// need "basically everything" within a `Module` to stay alive once it's
/// instantiated within a store. While there's some smaller portions that could
/// theoretically be omitted as they're not needed by the store they're
/// currently small enough to not worry much about.
@@ -147,8 +152,13 @@ impl ModuleRegistry {
/// Fetches trap information about a program counter in a backtrace.
pub fn lookup_trap_code(&self, pc: usize) -> Option<TrapCode> {
let (module, offset) = self.module(pc)?;
wasmtime_environ::lookup_trap_code(module.compiled_module().trap_data(), offset)
match self.module_or_component(pc)? {
(ModuleOrComponent::Module(module), offset) => {
wasmtime_environ::lookup_trap_code(module.compiled_module().trap_data(), offset)
}
#[cfg(feature = "component-model")]
(ModuleOrComponent::Component(component), offset) => component.lookup_trap_code(offset),
}
}
/// Fetches frame information about a program counter in a backtrace.
@@ -160,9 +170,21 @@ impl ModuleRegistry {
/// boolean indicates whether the engine used to compile this module is
/// using environment variables to control debuginfo parsing.
pub(crate) fn lookup_frame_info(&self, pc: usize) -> Option<(FrameInfo, &Module)> {
let (module, offset) = self.module(pc)?;
let info = FrameInfo::new(module, offset)?;
Some((info, module))
match self.module_or_component(pc)? {
(ModuleOrComponent::Module(module), offset) => {
let info = FrameInfo::new(module, offset)?;
Some((info, module))
}
#[cfg(feature = "component-model")]
(ModuleOrComponent::Component(_), _) => {
// FIXME: should investigate whether it's worth preserving
// frame information on a `Component` to resolve a frame here.
// Note that this can be traced back to either a lowered
// function via a trampoline or an "always trap" function at
// this time which may be useful debugging information to have.
None
}
}
}
}
@@ -183,12 +205,19 @@ lazy_static::lazy_static! {
static ref GLOBAL_MODULES: RwLock<GlobalModuleRegistry> = Default::default();
}
type GlobalModuleRegistry = BTreeMap<usize, (usize, Arc<CompiledModule>)>;
type GlobalModuleRegistry = BTreeMap<usize, (usize, TrapInfo)>;
/// Per-region payload stored in `GlobalModuleRegistry`; `is_wasm_trap_pc`
/// uses it to decide whether a pc inside the region is a wasm trap.
#[derive(Clone)]
enum TrapInfo {
/// A core wasm module; trap lookup goes through its `trap_data()`.
Module(Arc<CompiledModule>),
/// A component's text section. `register_component` fills this with the
/// `start + trap_offset` value of each "always trap" function, and
/// `is_wasm_trap_pc` binary-searches it, so it must be sorted ascending.
#[cfg(feature = "component-model")]
Component(Arc<Vec<u32>>),
}
/// Returns whether the `pc`, according to globally registered information,
/// is a wasm trap or not.
pub fn is_wasm_trap_pc(pc: usize) -> bool {
let (module, text_offset) = {
let (trap_info, text_offset) = {
let all_modules = GLOBAL_MODULES.read().unwrap();
let (end, (start, module)) = match all_modules.range(pc..).next() {
@@ -201,7 +230,16 @@ pub fn is_wasm_trap_pc(pc: usize) -> bool {
(module.clone(), pc - *start)
};
wasmtime_environ::lookup_trap_code(module.trap_data(), text_offset).is_some()
match trap_info {
TrapInfo::Module(module) => {
wasmtime_environ::lookup_trap_code(module.trap_data(), text_offset).is_some()
}
#[cfg(feature = "component-model")]
TrapInfo::Component(traps) => {
let offset = u32::try_from(text_offset).unwrap();
traps.binary_search(&offset).is_ok()
}
}
}
/// Registers a new region of code.
@@ -212,7 +250,7 @@ pub fn is_wasm_trap_pc(pc: usize) -> bool {
/// This is required to enable traps to work correctly since the signal handler
/// will look in the `GLOBAL_MODULES` list to determine whether a particular pc
/// is a trap or not.
pub fn register(module: &Arc<CompiledModule>) {
pub fn register_module(module: &Arc<CompiledModule>) {
let code = module.code();
if code.is_empty() {
return;
@@ -222,14 +260,14 @@ pub fn register(module: &Arc<CompiledModule>) {
let prev = GLOBAL_MODULES
.write()
.unwrap()
.insert(end, (start, module.clone()));
.insert(end, (start, TrapInfo::Module(module.clone())));
assert!(prev.is_none());
}
/// Unregisters a module from the global map.
///
/// Must hae been previously registered with `register`.
pub fn unregister(module: &Arc<CompiledModule>) {
/// Must have been previously registered with `register_module`.
pub fn unregister_module(module: &Arc<CompiledModule>) {
let code = module.code();
if code.is_empty() {
return;
@@ -239,6 +277,39 @@ pub fn unregister(module: &Arc<CompiledModule>) {
assert!(module.is_some());
}
/// Component counterpart of `register_module`: records the address range of
/// `text` in the global registry together with the text-relative offset
/// (`start + trap_offset`) of every "always trap" function in `traps`.
///
/// Does nothing when `text` is empty. Panics if an entry ending at the same
/// address is already registered.
#[cfg(feature = "component-model")]
pub fn register_component(text: &[u8], traps: &PrimaryMap<RuntimeAlwaysTrapIndex, AlwaysTrapInfo>) {
    if text.is_empty() {
        return;
    }
    let start = text.as_ptr() as usize;
    let end = start + text.len();
    // Build the offset table before taking the write lock.
    let trap_offsets: Vec<u32> = traps
        .values()
        .map(|info| info.start + info.trap_offset)
        .collect();
    let entry = (start, TrapInfo::Component(Arc::new(trap_offsets)));
    let prev = GLOBAL_MODULES.write().unwrap().insert(end, entry);
    assert!(prev.is_none());
}
/// Component counterpart of `unregister_module`: removes the global registry
/// entry keyed by the end address of `text`.
///
/// Does nothing when `text` is empty. Panics if no such entry exists.
#[cfg(feature = "component-model")]
pub fn unregister_component(text: &[u8]) {
    if text.is_empty() {
        return;
    }
    // Entries are keyed by the one-past-the-end address of the region.
    let end = text.as_ptr() as usize + text.len();
    let removed = GLOBAL_MODULES.write().unwrap().remove(&end);
    assert!(removed.is_some());
}
#[test]
fn test_frame_info() -> Result<(), anyhow::Error> {
use crate::*;

View File

@@ -87,6 +87,13 @@ pub enum TrapCode {
/// Execution has potentially run too long and may be interrupted.
Interrupt,
/// When the `component-model` feature is enabled this trap represents a
/// function that was `canon lift`'d, then `canon lower`'d, then called.
/// This combination of creation of a function in the component model
/// generates a function that always traps and, when called, produces this
/// flavor of trap.
AlwaysTrapAdapter,
}
impl TrapCode {
@@ -104,6 +111,7 @@ impl TrapCode {
EnvTrapCode::BadConversionToInteger => TrapCode::BadConversionToInteger,
EnvTrapCode::UnreachableCodeReached => TrapCode::UnreachableCodeReached,
EnvTrapCode::Interrupt => TrapCode::Interrupt,
EnvTrapCode::AlwaysTrapAdapter => TrapCode::AlwaysTrapAdapter,
}
}
}
@@ -123,6 +131,7 @@ impl fmt::Display for TrapCode {
BadConversionToInteger => "invalid conversion to integer",
UnreachableCodeReached => "wasm `unreachable` instruction executed",
Interrupt => "interrupt",
AlwaysTrapAdapter => "degenerate component adapter called",
};
write!(f, "{}", desc)
}