diff --git a/crates/wasmtime/src/component/component.rs b/crates/wasmtime/src/component/component.rs index c2f57c7694..5528543855 100644 --- a/crates/wasmtime/src/component/component.rs +++ b/crates/wasmtime/src/component/component.rs @@ -1,6 +1,7 @@ use crate::signatures::SignatureCollection; use crate::{Engine, Module}; use anyhow::{bail, Context, Result}; +use std::collections::HashMap; use std::fs; use std::ops::Range; use std::path::Path; @@ -164,13 +165,33 @@ impl Component { let code = trampoline_obj.publish()?; let text = wasmtime_jit::subslice_range(code.text, code.mmap); + // This map is used to register all known tramplines in the + // `SignatureCollection` created below. This is later consulted during + // `ModuleRegistry::lookup_trampoline` if a trampoline needs to be + // located for a signature index where the original function pointer + // is that of the `trampolines` created above. + // + // This situation arises when a core wasm module imports a lowered + // function and then immediately exports it. Wasmtime will lookup an + // entry trampoline for the exported function which is actually a + // lowered host function, hence an entry in the `trampolines` variable + // above, and the type of that function will be stored in this + // `vmtrampolines` map since the type is guaranteed to have escaped + // from at least one of the modules we compiled prior. + let mut vmtrampolines = HashMap::new(); + for (_, module) in static_modules.iter() { + for (idx, trampoline, _) in module.compiled_module().trampolines() { + vmtrampolines.insert(idx, trampoline); + } + } + // FIXME: for the same reason as above where each module is // re-registering everything this should only be registered once. This // is benign for now but could do with refactorings later on. let signatures = SignatureCollection::new_for_module( engine.signatures(), types.module_types(), - [].into_iter(), + vmtrampolines.into_iter(), ); Ok(Component { @@ -202,9 +223,13 @@ impl Component { &self.inner.signatures } + pub(crate) fn text(&self) -> &[u8] { + &self.inner.trampoline_obj.mmap()[self.inner.text.clone()] + } + pub(crate) fn trampoline_ptr(&self, index: LoweredIndex) -> NonNull { let info = &self.inner.trampolines[index]; - let text = &self.inner.trampoline_obj.mmap()[self.inner.text.clone()]; + let text = self.text(); let trampoline = &text[info.start as usize..][..info.length as usize]; NonNull::new(trampoline.as_ptr() as *mut VMFunctionBody).unwrap() } diff --git a/crates/wasmtime/src/component/instance.rs b/crates/wasmtime/src/component/instance.rs index b0fe5b836e..313b68fba9 100644 --- a/crates/wasmtime/src/component/instance.rs +++ b/crates/wasmtime/src/component/instance.rs @@ -210,6 +210,7 @@ impl<'a> Instantiator<'a> { imports: &'a PrimaryMap, ) -> Instantiator<'a> { let env_component = component.env_component(); + store.modules_mut().register_component(component); Instantiator { component, imports, diff --git a/crates/wasmtime/src/instance.rs b/crates/wasmtime/src/instance.rs index bf2cf4680e..8d38339a1b 100644 --- a/crates/wasmtime/src/instance.rs +++ b/crates/wasmtime/src/instance.rs @@ -242,7 +242,7 @@ impl Instance { // Register the module just before instantiation to ensure we keep the module // properly referenced while in use by the store. - store.modules_mut().register(module); + store.modules_mut().register_module(module); // The first thing we do is issue an instance allocation request // to the instance allocator. This, on success, will give us an diff --git a/crates/wasmtime/src/module/registry.rs b/crates/wasmtime/src/module/registry.rs index 6ea18299fd..bd6e48b2c9 100644 --- a/crates/wasmtime/src/module/registry.rs +++ b/crates/wasmtime/src/module/registry.rs @@ -1,5 +1,7 @@ //! Implements a registry of modules for a store. +#[cfg(feature = "component-model")] +use crate::component::Component; use crate::{Engine, Module}; use std::{ collections::BTreeMap, @@ -24,12 +26,21 @@ lazy_static::lazy_static! { #[derive(Default)] pub struct ModuleRegistry { // Keyed by the end address of the module's code in memory. - modules_with_code: BTreeMap, + // + // The value here is the start address and the module/component it + // corresponds to. + modules_with_code: BTreeMap, // Preserved for keeping data segments alive or similar modules_without_code: Vec, } +enum ModuleOrComponent { + Module(Module), + #[cfg(feature = "component-model")] + Component(Component), +} + fn start(module: &Module) -> usize { assert!(!module.compiled_module().code().is_empty()); module.compiled_module().code().as_ptr() as usize @@ -42,15 +53,23 @@ impl ModuleRegistry { } fn module(&self, pc: usize) -> Option<&Module> { - let (end, module) = self.modules_with_code.range(pc..).next()?; - if pc < start(module) || *end < pc { + match self.module_or_component(pc)? { + ModuleOrComponent::Module(m) => Some(m), + #[cfg(feature = "component-model")] + ModuleOrComponent::Component(_) => None, + } + } + + fn module_or_component(&self, pc: usize) -> Option<&ModuleOrComponent> { + let (end, (start, module)) = self.modules_with_code.range(pc..).next()?; + if pc < *start || *end < pc { return None; } Some(module) } /// Registers a new module with the registry. - pub fn register(&mut self, module: &Module) { + pub fn register_module(&mut self, module: &Module) { let compiled_module = module.compiled_module(); // If there's not actually any functions in this module then we may @@ -61,39 +80,70 @@ impl ModuleRegistry { // modules and retain them. if compiled_module.finished_functions().len() == 0 { self.modules_without_code.push(module.clone()); + } else { + // The module code range is exclusive for end, so make it inclusive as it + // may be a valid PC value + let start_addr = start(module); + let end_addr = start_addr + compiled_module.code().len() - 1; + self.register( + start_addr, + end_addr, + ModuleOrComponent::Module(module.clone()), + ); + } + } + + #[cfg(feature = "component-model")] + pub fn register_component(&mut self, component: &Component) { + // If there's no text section associated with this component (e.g. no + // lowered functions) then there's nothing to register, otherwise it's + // registered along the same lines as modules above. + // + // Note that empty components don't need retaining here since it doesn't + // have data segments like empty modules. + let text = component.text(); + if text.is_empty() { return; } + let start = text.as_ptr() as usize; + self.register( + start, + start + text.len() - 1, + ModuleOrComponent::Component(component.clone()), + ); + } - // The module code range is exclusive for end, so make it inclusive as it - // may be a valid PC value - let start_addr = start(module); - let end_addr = start_addr + compiled_module.code().len() - 1; - + /// Registers a new module with the registry. + fn register(&mut self, start_addr: usize, end_addr: usize, item: ModuleOrComponent) { // Ensure the module isn't already present in the registry // This is expected when a module is instantiated multiple times in the // same store - if let Some(m) = self.modules_with_code.get(&end_addr) { - assert_eq!(start(m), start_addr); + if let Some((other_start, _)) = self.modules_with_code.get(&end_addr) { + assert_eq!(*other_start, start_addr); return; } // Assert that this module's code doesn't collide with any other // registered modules - if let Some((_, prev)) = self.modules_with_code.range(end_addr..).next() { - assert!(start(prev) > end_addr); + if let Some((_, (prev_start, _))) = self.modules_with_code.range(start_addr..).next() { + assert!(*prev_start > end_addr); } if let Some((prev_end, _)) = self.modules_with_code.range(..=start_addr).next_back() { assert!(*prev_end < start_addr); } - let prev = self.modules_with_code.insert(end_addr, module.clone()); + let prev = self.modules_with_code.insert(end_addr, (start_addr, item)); assert!(prev.is_none()); } /// Looks up a trampoline from an anyfunc. pub fn lookup_trampoline(&self, anyfunc: &VMCallerCheckedAnyfunc) -> Option { - let module = self.module(anyfunc.func_ptr.as_ptr() as usize)?; - module.signatures().trampoline(anyfunc.type_index) + let signatures = match self.module_or_component(anyfunc.func_ptr.as_ptr() as usize)? { + ModuleOrComponent::Module(m) => m.signatures(), + #[cfg(feature = "component-model")] + ModuleOrComponent::Component(c) => c.signatures(), + }; + signatures.trampoline(anyfunc.type_index) } } diff --git a/crates/wasmtime/src/signatures.rs b/crates/wasmtime/src/signatures.rs index 1a9dabd299..d3c65eb87f 100644 --- a/crates/wasmtime/src/signatures.rs +++ b/crates/wasmtime/src/signatures.rs @@ -19,7 +19,7 @@ use wasmtime_runtime::{VMSharedSignatureIndex, VMTrampoline}; pub struct SignatureCollection { registry: Arc>, signatures: PrimaryMap, - trampolines: HashMap, + trampolines: HashMap, } impl SignatureCollection { @@ -59,9 +59,7 @@ impl SignatureCollection { /// Gets a trampoline for a registered signature. pub fn trampoline(&self, index: VMSharedSignatureIndex) -> Option { - self.trampolines - .get(&index) - .map(|(_, trampoline)| *trampoline) + self.trampolines.get(&index).copied() } } @@ -93,7 +91,7 @@ impl SignatureRegistryInner { trampolines: impl Iterator, ) -> ( PrimaryMap, - HashMap, + HashMap, ) { let mut sigs = PrimaryMap::default(); let mut map = HashMap::default(); @@ -104,7 +102,7 @@ impl SignatureRegistryInner { } for (index, trampoline) in trampolines { - map.insert(sigs[index], (1, trampoline)); + map.insert(sigs[index], trampoline); } (sigs, map) @@ -165,8 +163,8 @@ impl SignatureRegistryInner { } else { // Otherwise, use the trampolines map, which has reference counts related // to the stored index - for (index, (count, _)) in collection.trampolines.iter() { - self.unregister_entry(*index, *count); + for (index, _) in collection.trampolines.iter() { + self.unregister_entry(*index, 1); } } } diff --git a/tests/all/component_model/func.rs b/tests/all/component_model/func.rs index f5666d0ac9..23024a468a 100644 --- a/tests/all/component_model/func.rs +++ b/tests/all/component_model/func.rs @@ -3,7 +3,7 @@ use anyhow::Result; use std::rc::Rc; use std::sync::Arc; use wasmtime::component::*; -use wasmtime::{Store, Trap, TrapCode}; +use wasmtime::{Store, StoreContextMut, Trap, TrapCode}; const CANON_32BIT_NAN: u32 = 0b01111111110000000000000000000000; const CANON_64BIT_NAN: u64 = 0b0111111111111000000000000000000000000000000000000000000000000000; @@ -1858,3 +1858,58 @@ fn invalid_alignment() -> Result<()> { Ok(()) } + +#[test] +fn drop_component_still_works() -> Result<()> { + let component = r#" + (component + (import "f" (func $f)) + + (core func $f_lower + (canon lower (func $f)) + ) + (core module $m + (import "" "" (func $f)) + + (func $f2 + call $f + call $f + ) + + (export "f" (func $f2)) + ) + (core instance $i (instantiate $m + (with "" (instance + (export "" (func $f_lower)) + )) + )) + (func (export "f") + (canon lift + (core func $i "f") + ) + ) + ) + "#; + + let (mut store, instance) = { + let engine = super::engine(); + let component = Component::new(&engine, component)?; + let mut store = Store::new(&engine, 0); + let mut linker = Linker::new(&engine); + linker + .root() + .func_wrap("f", |mut store: StoreContextMut<'_, u32>| -> Result<()> { + *store.data_mut() += 1; + Ok(()) + })?; + let instance = linker.instantiate(&mut store, &component)?; + (store, instance) + }; + + let f = instance.get_typed_func::<(), (), _>(&mut store, "f")?; + assert_eq!(*store.data(), 0); + f.call(&mut store, ())?; + assert_eq!(*store.data(), 2); + + Ok(()) +} diff --git a/tests/all/component_model/import.rs b/tests/all/component_model/import.rs index 00a516d8ff..2385337f18 100644 --- a/tests/all/component_model/import.rs +++ b/tests/all/component_model/import.rs @@ -614,3 +614,50 @@ fn bad_import_alignment() -> Result<()> { assert!(trap.to_string().contains("pointer not aligned"), "{}", trap); Ok(()) } + +#[test] +fn no_actual_wasm_code() -> Result<()> { + let component = r#" + (component + (import "f" (func $f)) + + (core func $f_lower + (canon lower (func $f)) + ) + (core module $m + (import "" "" (func $f)) + (export "f" (func $f)) + ) + (core instance $i (instantiate $m + (with "" (instance + (export "" (func $f_lower)) + )) + )) + (func (export "thunk") + (canon lift + (core func $i "f") + ) + ) + ) + "#; + + let engine = super::engine(); + let component = Component::new(&engine, component)?; + let mut store = Store::new(&engine, 0); + let mut linker = Linker::new(&engine); + linker + .root() + .func_wrap("f", |mut store: StoreContextMut<'_, u32>| -> Result<()> { + *store.data_mut() += 1; + Ok(()) + })?; + + let instance = linker.instantiate(&mut store, &component)?; + let thunk = instance.get_typed_func::<(), (), _>(&mut store, "thunk")?; + + assert_eq!(*store.data(), 0); + thunk.call(&mut store, ())?; + assert_eq!(*store.data(), 1); + + Ok(()) +}