From e6f399419c60c0306d13e39c2f955090f7da943f Mon Sep 17 00:00:00 2001 From: bjorn3 Date: Wed, 9 Jun 2021 18:51:11 +0200 Subject: [PATCH] Atomic hotswapping in JIT mode (#2786) * Introduce new_got_entry and new_plt_entry functions * Return NonNull<*const u8> from get_got_address * Make GOT entry writes atomic * Defer GOT updates until relocations and protection Co-authored-by: Alan Egerton --- cranelift/jit/src/backend.rs | 203 +++++++++++++++++------------------ 1 file changed, 98 insertions(+), 105 deletions(-) diff --git a/cranelift/jit/src/backend.rs b/cranelift/jit/src/backend.rs index 4389455d8d..fe55719c44 100644 --- a/cranelift/jit/src/backend.rs +++ b/cranelift/jit/src/backend.rs @@ -23,6 +23,7 @@ use std::ffi::CString; use std::io::Write; use std::ptr; use std::ptr::NonNull; +use std::sync::atomic::{AtomicPtr, Ordering}; use target_lexicon::PointerWidth; #[cfg(windows)] use winapi; @@ -129,6 +130,15 @@ impl JITBuilder { } } +/// A pending update to the GOT. +struct GotUpdate { + /// The entry that is to be updated. + entry: NonNull>, + + /// The new value of the entry. + ptr: *const u8, +} + /// A `JITModule` implements `Module` and emits code and data into memory where it can be /// directly called and accessed. /// @@ -140,15 +150,18 @@ pub struct JITModule { libcall_names: Box String>, memory: MemoryHandle, declarations: ModuleDeclarations, - function_got_entries: SecondaryMap>>, + function_got_entries: SecondaryMap>>>, function_plt_entries: SecondaryMap>>, - data_object_got_entries: SecondaryMap>>, - libcall_got_entries: HashMap>, + data_object_got_entries: SecondaryMap>>>, + libcall_got_entries: HashMap>>, libcall_plt_entries: HashMap>, compiled_functions: SecondaryMap>, compiled_data_objects: SecondaryMap>, functions_to_finalize: Vec, data_objects_to_finalize: Vec, + + /// Updates to the GOT awaiting relocations to be made and region protections to be set + pending_got_updates: Vec, } /// A handle to allow freeing memory allocated by the `Module`. @@ -180,54 +193,53 @@ impl JITModule { .or_else(|| lookup_with_dlsym(name)) } - fn new_func_plt_entry(&mut self, id: FuncId, val: *const u8) { + fn new_got_entry(&mut self, val: *const u8) -> NonNull> { let got_entry = self .memory .writable .allocate( - std::mem::size_of::<*const u8>(), - std::mem::align_of::<*const u8>().try_into().unwrap(), + std::mem::size_of::>(), + std::mem::align_of::>().try_into().unwrap(), ) .unwrap() - .cast::<*const u8>(); - self.function_got_entries[id] = Some(NonNull::new(got_entry).unwrap()); + .cast::>(); unsafe { - std::ptr::write(got_entry, val); + std::ptr::write(got_entry, AtomicPtr::new(val as *mut _)); } + NonNull::new(got_entry).unwrap() + } + + fn new_plt_entry(&mut self, got_entry: NonNull>) -> NonNull<[u8; 16]> { let plt_entry = self .memory .code .allocate(std::mem::size_of::<[u8; 16]>(), EXECUTABLE_DATA_ALIGNMENT) .unwrap() .cast::<[u8; 16]>(); - self.record_function_for_perf( - plt_entry as *mut _, - std::mem::size_of::<[u8; 16]>(), - &format!("{}@plt", self.declarations.get_function_decl(id).name), - ); - self.function_plt_entries[id] = Some(NonNull::new(plt_entry).unwrap()); unsafe { Self::write_plt_entry_bytes(plt_entry, got_entry); } + NonNull::new(plt_entry).unwrap() + } + + fn new_func_plt_entry(&mut self, id: FuncId, val: *const u8) { + let got_entry = self.new_got_entry(val); + self.function_got_entries[id] = Some(got_entry); + let plt_entry = self.new_plt_entry(got_entry); + self.record_function_for_perf( + plt_entry.as_ptr().cast(), + std::mem::size_of::<[u8; 16]>(), + &format!("{}@plt", self.declarations.get_function_decl(id).name), + ); + self.function_plt_entries[id] = Some(plt_entry); } fn new_data_got_entry(&mut self, id: DataId, val: *const u8) { - let got_entry = self - .memory - .writable - .allocate( - std::mem::size_of::<*const u8>(), - std::mem::align_of::<*const u8>().try_into().unwrap(), - ) - .unwrap() - .cast::<*const u8>(); - self.data_object_got_entries[id] = Some(NonNull::new(got_entry).unwrap()); - unsafe { - std::ptr::write(got_entry, val); - } + let got_entry = self.new_got_entry(val); + self.data_object_got_entries[id] = Some(got_entry); } - unsafe fn write_plt_entry_bytes(plt_ptr: *mut [u8; 16], got_ptr: *mut *const u8) { + unsafe fn write_plt_entry_bytes(plt_ptr: *mut [u8; 16], got_ptr: NonNull>) { assert!( cfg!(target_arch = "x86_64"), "PLT is currently only supported on x86_64" @@ -236,7 +248,7 @@ impl JITModule { let mut plt_val = [ 0xff, 0x25, 0, 0, 0, 0, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b, ]; - let what = got_ptr as isize - 4; + let what = got_ptr.as_ptr() as isize - 4; let at = plt_ptr as isize + 2; plt_val[2..6].copy_from_slice(&i32::to_ne_bytes(i32::try_from(what - at).unwrap())); std::ptr::write(plt_ptr, plt_val); @@ -289,32 +301,25 @@ impl JITModule { /// /// Panics if there's no entry in the table for the given function. pub fn read_got_entry(&self, func_id: FuncId) -> *const u8 { - unsafe { *self.function_got_entries[func_id].unwrap().as_ptr() } + let got_entry = self.function_got_entries[func_id].unwrap(); + unsafe { got_entry.as_ref() }.load(Ordering::SeqCst) } - fn get_got_address(&self, name: &ir::ExternalName) -> *const u8 { + fn get_got_address(&self, name: &ir::ExternalName) -> NonNull> { match *name { ir::ExternalName::User { .. } => { if ModuleDeclarations::is_function(name) { let func_id = FuncId::from_name(name); - self.function_got_entries[func_id] - .unwrap() - .as_ptr() - .cast::() + self.function_got_entries[func_id].unwrap() } else { let data_id = DataId::from_name(name); - self.data_object_got_entries[data_id] - .unwrap() - .as_ptr() - .cast::() + self.data_object_got_entries[data_id].unwrap() } } - ir::ExternalName::LibCall(ref libcall) => self + ir::ExternalName::LibCall(ref libcall) => *self .libcall_got_entries .get(libcall) - .unwrap_or_else(|| panic!("can't resolve libcall {}", libcall)) - .as_ptr() - .cast::(), + .unwrap_or_else(|| panic!("can't resolve libcall {}", libcall)), _ => panic!("invalid ExternalName {}", name), } } @@ -406,7 +411,7 @@ impl JITModule { .expect("function must be compiled before it can be finalized"); func.perform_relocations( |name| self.get_address(name), - |name| self.get_got_address(name), + |name| self.get_got_address(name).as_ptr().cast(), |name| self.get_plt_address(name), ); } @@ -419,7 +424,7 @@ impl JITModule { .expect("data object must be compiled before it can be finalized"); data.perform_relocations( |name| self.get_address(name), - |name| self.get_got_address(name), + |name| self.get_got_address(name).as_ptr().cast(), |name| self.get_plt_address(name), ); } @@ -427,6 +432,10 @@ impl JITModule { // Now that we're done patching, prepare the memory for execution! self.memory.readonly.set_readonly(); self.memory.code.set_readable_and_executable(); + + for update in self.pending_got_updates.drain(..) { + unsafe { update.entry.as_ref() }.store(update.ptr as *mut _, Ordering::SeqCst); + } } /// Create a new `JITModule`. @@ -438,33 +447,38 @@ impl JITModule { ); } - let mut memory = MemoryHandle { - code: Memory::new(), - readonly: Memory::new(), - writable: Memory::new(), + let mut module = Self { + isa: builder.isa, + hotswap_enabled: builder.hotswap_enabled, + symbols: builder.symbols, + libcall_names: builder.libcall_names, + memory: MemoryHandle { + code: Memory::new(), + readonly: Memory::new(), + writable: Memory::new(), + }, + declarations: ModuleDeclarations::default(), + function_got_entries: SecondaryMap::new(), + function_plt_entries: SecondaryMap::new(), + data_object_got_entries: SecondaryMap::new(), + libcall_got_entries: HashMap::new(), + libcall_plt_entries: HashMap::new(), + compiled_functions: SecondaryMap::new(), + compiled_data_objects: SecondaryMap::new(), + functions_to_finalize: Vec::new(), + data_objects_to_finalize: Vec::new(), + pending_got_updates: Vec::new(), }; - let mut libcall_got_entries = HashMap::new(); - let mut libcall_plt_entries = HashMap::new(); - // Pre-create a GOT and PLT entry for each libcall. - let all_libcalls = if builder.isa.flags().is_pic() { + let all_libcalls = if module.isa.flags().is_pic() { ir::LibCall::all_libcalls() } else { &[] // Not PIC, so no GOT and PLT entries necessary }; for &libcall in all_libcalls { - let got_entry = memory - .writable - .allocate( - std::mem::size_of::<*const u8>(), - std::mem::align_of::<*const u8>().try_into().unwrap(), - ) - .unwrap() - .cast::<*const u8>(); - libcall_got_entries.insert(libcall, NonNull::new(got_entry).unwrap()); - let sym = (builder.libcall_names)(libcall); - let addr = if let Some(addr) = builder + let sym = (module.libcall_names)(libcall); + let addr = if let Some(addr) = module .symbols .get(&sym) .copied() @@ -474,37 +488,13 @@ impl JITModule { } else { continue; }; - unsafe { - std::ptr::write(got_entry, addr); - } - let plt_entry = memory - .code - .allocate(std::mem::size_of::<[u8; 16]>(), EXECUTABLE_DATA_ALIGNMENT) - .unwrap() - .cast::<[u8; 16]>(); - libcall_plt_entries.insert(libcall, NonNull::new(plt_entry).unwrap()); - unsafe { - Self::write_plt_entry_bytes(plt_entry, got_entry); - } + let got_entry = module.new_got_entry(addr); + module.libcall_got_entries.insert(libcall, got_entry); + let plt_entry = module.new_plt_entry(got_entry); + module.libcall_plt_entries.insert(libcall, plt_entry); } - Self { - isa: builder.isa, - hotswap_enabled: builder.hotswap_enabled, - symbols: builder.symbols, - libcall_names: builder.libcall_names, - memory, - declarations: ModuleDeclarations::default(), - function_got_entries: SecondaryMap::new(), - function_plt_entries: SecondaryMap::new(), - data_object_got_entries: SecondaryMap::new(), - libcall_got_entries, - libcall_plt_entries, - compiled_functions: SecondaryMap::new(), - compiled_data_objects: SecondaryMap::new(), - functions_to_finalize: Vec::new(), - data_objects_to_finalize: Vec::new(), - } + module } /// Allow a single future `define_function` on a previously defined function. This allows for @@ -682,9 +672,10 @@ impl Module for JITModule { }); if self.isa.flags().is_pic() { - unsafe { - std::ptr::write(self.function_got_entries[id].unwrap().as_ptr(), ptr); - } + self.pending_got_updates.push(GotUpdate { + entry: self.function_got_entries[id].unwrap(), + ptr, + }) } if self.hotswap_enabled { @@ -704,7 +695,7 @@ impl Module for JITModule { .cast::(), _ => panic!("invalid ExternalName {}", name), }, - |name| self.get_got_address(name), + |name| self.get_got_address(name).as_ptr().cast(), |name| self.get_plt_address(name), ); } else { @@ -754,9 +745,10 @@ impl Module for JITModule { }); if self.isa.flags().is_pic() { - unsafe { - std::ptr::write(self.function_got_entries[id].unwrap().as_ptr(), ptr); - } + self.pending_got_updates.push(GotUpdate { + entry: self.function_got_entries[id].unwrap(), + ptr, + }) } if self.hotswap_enabled { @@ -765,7 +757,7 @@ impl Module for JITModule { .unwrap() .perform_relocations( |name| unreachable!("non GOT or PLT relocation in function {} to {}", id, name), - |name| self.get_got_address(name), + |name| self.get_got_address(name).as_ptr().cast(), |name| self.get_plt_address(name), ); } else { @@ -836,9 +828,10 @@ impl Module for JITModule { self.compiled_data_objects[id] = Some(CompiledBlob { ptr, size, relocs }); self.data_objects_to_finalize.push(id); if self.isa.flags().is_pic() { - unsafe { - std::ptr::write(self.data_object_got_entries[id].unwrap().as_ptr(), ptr); - } + self.pending_got_updates.push(GotUpdate { + entry: self.data_object_got_entries[id].unwrap(), + ptr, + }) } Ok(())