diff --git a/crates/runtime/src/instance/allocator.rs b/crates/runtime/src/instance/allocator.rs index c45270e173..aae17c7f4d 100644 --- a/crates/runtime/src/instance/allocator.rs +++ b/crates/runtime/src/instance/allocator.rs @@ -6,8 +6,8 @@ use crate::table::{Table, TableElement}; use crate::traphandlers::Trap; use crate::vmcontext::{ VMBuiltinFunctionsArray, VMCallerCheckedAnyfunc, VMContext, VMFunctionBody, VMFunctionImport, - VMGlobalDefinition, VMGlobalImport, VMInterrupts, VMMemoryDefinition, VMMemoryImport, - VMSharedSignatureIndex, VMTableDefinition, VMTableImport, + VMGlobalDefinition, VMGlobalImport, VMInterrupts, VMMemoryImport, VMSharedSignatureIndex, + VMTableImport, }; use std::alloc; use std::any::Any; @@ -391,31 +391,18 @@ fn initialize_instance( Ok(()) } -unsafe fn initialize_vmcontext( - instance: &Instance, - functions: &[VMFunctionImport], - tables: &[VMTableImport], - memories: &[VMMemoryImport], - globals: &[VMGlobalImport], - finished_functions: &PrimaryMap, - lookup_shared_signature: &dyn Fn(SignatureIndex) -> VMSharedSignatureIndex, - interrupts: *const VMInterrupts, - externref_activations_table: *mut VMExternRefActivationsTable, - stack_map_registry: *mut StackMapRegistry, - get_mem_def: impl Fn(DefinedMemoryIndex) -> VMMemoryDefinition, - get_table_def: impl Fn(DefinedTableIndex) -> VMTableDefinition, -) { +unsafe fn initialize_vmcontext(instance: &Instance, req: InstanceAllocationRequest) { let module = &instance.module; - *instance.interrupts() = interrupts; - *instance.externref_activations_table() = externref_activations_table; - *instance.stack_map_registry() = stack_map_registry; + *instance.interrupts() = req.interrupts; + *instance.externref_activations_table() = req.externref_activations_table; + *instance.stack_map_registry() = req.stack_map_registry; // Initialize shared signatures let mut ptr = instance.signature_ids_ptr(); for sig in module.types.values() { *ptr = match sig { - ModuleType::Function(sig) => lookup_shared_signature(*sig), + ModuleType::Function(sig) => (req.lookup_shared_signature)(*sig), _ => VMSharedSignatureIndex::new(u32::max_value()), }; ptr = ptr.add(1); @@ -428,38 +415,38 @@ unsafe fn initialize_vmcontext( ); // Initialize the imports - debug_assert_eq!(functions.len(), module.num_imported_funcs); + debug_assert_eq!(req.imports.functions.len(), module.num_imported_funcs); ptr::copy( - functions.as_ptr(), + req.imports.functions.as_ptr(), instance.imported_functions_ptr() as *mut VMFunctionImport, - functions.len(), + req.imports.functions.len(), ); - debug_assert_eq!(tables.len(), module.num_imported_tables); + debug_assert_eq!(req.imports.tables.len(), module.num_imported_tables); ptr::copy( - tables.as_ptr(), + req.imports.tables.as_ptr(), instance.imported_tables_ptr() as *mut VMTableImport, - tables.len(), + req.imports.tables.len(), ); - debug_assert_eq!(memories.len(), module.num_imported_memories); + debug_assert_eq!(req.imports.memories.len(), module.num_imported_memories); ptr::copy( - memories.as_ptr(), + req.imports.memories.as_ptr(), instance.imported_memories_ptr() as *mut VMMemoryImport, - memories.len(), + req.imports.memories.len(), ); - debug_assert_eq!(globals.len(), module.num_imported_globals); + debug_assert_eq!(req.imports.globals.len(), module.num_imported_globals); ptr::copy( - globals.as_ptr(), + req.imports.globals.as_ptr(), instance.imported_globals_ptr() as *mut VMGlobalImport, - globals.len(), + req.imports.globals.len(), ); // Initialize the functions for (index, sig) in instance.module.functions.iter() { - let type_index = lookup_shared_signature(*sig); + let type_index = (req.lookup_shared_signature)(*sig); let (func_ptr, vmctx) = if let Some(def_index) = instance.module.defined_func_index(index) { ( - NonNull::new(finished_functions[def_index] as *mut _).unwrap(), + NonNull::new(req.finished_functions[def_index] as *mut _).unwrap(), instance.vmctx_ptr(), ) } else { @@ -480,14 +467,17 @@ unsafe fn initialize_vmcontext( // Initialize the defined tables let mut ptr = instance.tables_ptr(); for i in 0..module.table_plans.len() - module.num_imported_tables { - ptr::write(ptr, get_table_def(DefinedTableIndex::new(i))); + ptr::write(ptr, instance.tables[DefinedTableIndex::new(i)].vmtable()); ptr = ptr.add(1); } // Initialize the defined memories let mut ptr = instance.memories_ptr(); for i in 0..module.memory_plans.len() - module.num_imported_memories { - ptr::write(ptr, get_mem_def(DefinedMemoryIndex::new(i))); + ptr::write( + ptr, + instance.memories[DefinedMemoryIndex::new(i)].vmmemory(), + ); ptr = ptr.add(1); } @@ -577,7 +567,7 @@ impl OnDemandInstanceAllocator { unsafe impl InstanceAllocator for OnDemandInstanceAllocator { unsafe fn allocate( &self, - req: InstanceAllocationRequest, + mut req: InstanceAllocationRequest, ) -> Result { debug_assert!(!req.externref_activations_table.is_null()); debug_assert!(!req.stack_map_registry.is_null()); @@ -585,6 +575,8 @@ unsafe impl InstanceAllocator for OnDemandInstanceAllocator { let memories = self.create_memories(&req.module)?; let tables = Self::create_tables(&req.module); + let host_state = std::mem::replace(&mut req.host_state, Box::new(())); + let handle = { let instance = Instance { module: req.module.clone(), @@ -595,7 +587,7 @@ unsafe impl InstanceAllocator for OnDemandInstanceAllocator { req.module.passive_elements.len(), )), dropped_data: RefCell::new(EntitySet::with_capacity(req.module.passive_data.len())), - host_state: req.host_state, + host_state, #[cfg(all(feature = "uffd", target_os = "linux"))] guard_page_faults: RefCell::new(Vec::new()), vmctx: VMContext {}, @@ -609,21 +601,7 @@ unsafe impl InstanceAllocator for OnDemandInstanceAllocator { InstanceHandle::new(instance_ptr) }; - let instance = handle.instance(); - initialize_vmcontext( - instance, - req.imports.functions, - req.imports.tables, - req.imports.memories, - req.imports.globals, - req.finished_functions, - req.lookup_shared_signature, - req.interrupts, - req.externref_activations_table, - req.stack_map_registry, - &|index| instance.memories[index].vmmemory(), - &|index| instance.tables[index].vmtable(), - ); + initialize_vmcontext(handle.instance(), req); Ok(handle) } diff --git a/crates/runtime/src/instance/allocator/pooling.rs b/crates/runtime/src/instance/allocator/pooling.rs index dfcbd7bb0a..394b231c62 100644 --- a/crates/runtime/src/instance/allocator/pooling.rs +++ b/crates/runtime/src/instance/allocator/pooling.rs @@ -430,7 +430,7 @@ impl InstancePool { fn allocate( &self, strategy: PoolingAllocationStrategy, - req: InstanceAllocationRequest, + mut req: InstanceAllocationRequest, ) -> Result { let index = { let mut free_list = self.free_list.lock().unwrap(); @@ -441,17 +441,19 @@ impl InstancePool { free_list.swap_remove(free_index) }; + let host_state = std::mem::replace(&mut req.host_state, Box::new(())); + unsafe { debug_assert!(index < self.max_instances); let instance = &mut *(self.mapping.as_mut_ptr().add(index * self.instance_size) as *mut Instance); - instance.module = req.module; + instance.module = req.module.clone(); instance.offsets = VMOffsets::new( std::mem::size_of::<*const u8>() as u8, instance.module.as_ref(), ); - instance.host_state = req.host_state; + instance.host_state = host_state; Self::set_instance_memories( instance, @@ -460,20 +462,7 @@ impl InstancePool { )?; Self::set_instance_tables(instance, self.tables.get(index), self.tables.max_elements)?; - initialize_vmcontext( - instance, - req.imports.functions, - req.imports.tables, - req.imports.memories, - req.imports.globals, - req.finished_functions, - req.lookup_shared_signature, - req.interrupts, - req.externref_activations_table, - req.stack_map_registry, - &|index| instance.memories[index].vmmemory(), - &|index| instance.tables[index].vmtable(), - ); + initialize_vmcontext(instance, req); Ok(InstanceHandle::new(instance as _)) } @@ -517,6 +506,9 @@ impl InstancePool { decommit(base, size); } } + + // Drop the host state + (*handle.instance).host_state = Box::new(()); } {