Atomic hotswapping in JIT mode (#2786)

* Introduce new_got_entry and new_plt_entry functions

* Return NonNull<*const u8> from get_got_address

* Make GOT entry writes atomic

* Defer GOT updates until relocations and protection

Co-authored-by: Alan Egerton <eggyal@gmail.com>
This commit is contained in:
bjorn3
2021-06-09 18:51:11 +02:00
committed by GitHub
parent 884a6500e9
commit e6f399419c

View File

@@ -23,6 +23,7 @@ use std::ffi::CString;
use std::io::Write; use std::io::Write;
use std::ptr; use std::ptr;
use std::ptr::NonNull; use std::ptr::NonNull;
use std::sync::atomic::{AtomicPtr, Ordering};
use target_lexicon::PointerWidth; use target_lexicon::PointerWidth;
#[cfg(windows)] #[cfg(windows)]
use winapi; use winapi;
@@ -129,6 +130,15 @@ impl JITBuilder {
} }
} }
/// A pending update to the GOT.
struct GotUpdate {
/// The entry that is to be updated.
entry: NonNull<AtomicPtr<u8>>,
/// The new value of the entry.
ptr: *const u8,
}
/// A `JITModule` implements `Module` and emits code and data into memory where it can be /// A `JITModule` implements `Module` and emits code and data into memory where it can be
/// directly called and accessed. /// directly called and accessed.
/// ///
@@ -140,15 +150,18 @@ pub struct JITModule {
libcall_names: Box<dyn Fn(ir::LibCall) -> String>, libcall_names: Box<dyn Fn(ir::LibCall) -> String>,
memory: MemoryHandle, memory: MemoryHandle,
declarations: ModuleDeclarations, declarations: ModuleDeclarations,
function_got_entries: SecondaryMap<FuncId, Option<NonNull<*const u8>>>, function_got_entries: SecondaryMap<FuncId, Option<NonNull<AtomicPtr<u8>>>>,
function_plt_entries: SecondaryMap<FuncId, Option<NonNull<[u8; 16]>>>, function_plt_entries: SecondaryMap<FuncId, Option<NonNull<[u8; 16]>>>,
data_object_got_entries: SecondaryMap<DataId, Option<NonNull<*const u8>>>, data_object_got_entries: SecondaryMap<DataId, Option<NonNull<AtomicPtr<u8>>>>,
libcall_got_entries: HashMap<ir::LibCall, NonNull<*const u8>>, libcall_got_entries: HashMap<ir::LibCall, NonNull<AtomicPtr<u8>>>,
libcall_plt_entries: HashMap<ir::LibCall, NonNull<[u8; 16]>>, libcall_plt_entries: HashMap<ir::LibCall, NonNull<[u8; 16]>>,
compiled_functions: SecondaryMap<FuncId, Option<CompiledBlob>>, compiled_functions: SecondaryMap<FuncId, Option<CompiledBlob>>,
compiled_data_objects: SecondaryMap<DataId, Option<CompiledBlob>>, compiled_data_objects: SecondaryMap<DataId, Option<CompiledBlob>>,
functions_to_finalize: Vec<FuncId>, functions_to_finalize: Vec<FuncId>,
data_objects_to_finalize: Vec<DataId>, data_objects_to_finalize: Vec<DataId>,
/// Updates to the GOT awaiting relocations to be made and region protections to be set
pending_got_updates: Vec<GotUpdate>,
} }
/// A handle to allow freeing memory allocated by the `Module`. /// A handle to allow freeing memory allocated by the `Module`.
@@ -180,54 +193,53 @@ impl JITModule {
.or_else(|| lookup_with_dlsym(name)) .or_else(|| lookup_with_dlsym(name))
} }
fn new_func_plt_entry(&mut self, id: FuncId, val: *const u8) { fn new_got_entry(&mut self, val: *const u8) -> NonNull<AtomicPtr<u8>> {
let got_entry = self let got_entry = self
.memory .memory
.writable .writable
.allocate( .allocate(
std::mem::size_of::<*const u8>(), std::mem::size_of::<AtomicPtr<u8>>(),
std::mem::align_of::<*const u8>().try_into().unwrap(), std::mem::align_of::<AtomicPtr<u8>>().try_into().unwrap(),
) )
.unwrap() .unwrap()
.cast::<*const u8>(); .cast::<AtomicPtr<u8>>();
self.function_got_entries[id] = Some(NonNull::new(got_entry).unwrap());
unsafe { unsafe {
std::ptr::write(got_entry, val); std::ptr::write(got_entry, AtomicPtr::new(val as *mut _));
} }
NonNull::new(got_entry).unwrap()
}
fn new_plt_entry(&mut self, got_entry: NonNull<AtomicPtr<u8>>) -> NonNull<[u8; 16]> {
let plt_entry = self let plt_entry = self
.memory .memory
.code .code
.allocate(std::mem::size_of::<[u8; 16]>(), EXECUTABLE_DATA_ALIGNMENT) .allocate(std::mem::size_of::<[u8; 16]>(), EXECUTABLE_DATA_ALIGNMENT)
.unwrap() .unwrap()
.cast::<[u8; 16]>(); .cast::<[u8; 16]>();
self.record_function_for_perf(
plt_entry as *mut _,
std::mem::size_of::<[u8; 16]>(),
&format!("{}@plt", self.declarations.get_function_decl(id).name),
);
self.function_plt_entries[id] = Some(NonNull::new(plt_entry).unwrap());
unsafe { unsafe {
Self::write_plt_entry_bytes(plt_entry, got_entry); Self::write_plt_entry_bytes(plt_entry, got_entry);
} }
NonNull::new(plt_entry).unwrap()
}
fn new_func_plt_entry(&mut self, id: FuncId, val: *const u8) {
let got_entry = self.new_got_entry(val);
self.function_got_entries[id] = Some(got_entry);
let plt_entry = self.new_plt_entry(got_entry);
self.record_function_for_perf(
plt_entry.as_ptr().cast(),
std::mem::size_of::<[u8; 16]>(),
&format!("{}@plt", self.declarations.get_function_decl(id).name),
);
self.function_plt_entries[id] = Some(plt_entry);
} }
fn new_data_got_entry(&mut self, id: DataId, val: *const u8) { fn new_data_got_entry(&mut self, id: DataId, val: *const u8) {
let got_entry = self let got_entry = self.new_got_entry(val);
.memory self.data_object_got_entries[id] = Some(got_entry);
.writable
.allocate(
std::mem::size_of::<*const u8>(),
std::mem::align_of::<*const u8>().try_into().unwrap(),
)
.unwrap()
.cast::<*const u8>();
self.data_object_got_entries[id] = Some(NonNull::new(got_entry).unwrap());
unsafe {
std::ptr::write(got_entry, val);
}
} }
unsafe fn write_plt_entry_bytes(plt_ptr: *mut [u8; 16], got_ptr: *mut *const u8) { unsafe fn write_plt_entry_bytes(plt_ptr: *mut [u8; 16], got_ptr: NonNull<AtomicPtr<u8>>) {
assert!( assert!(
cfg!(target_arch = "x86_64"), cfg!(target_arch = "x86_64"),
"PLT is currently only supported on x86_64" "PLT is currently only supported on x86_64"
@@ -236,7 +248,7 @@ impl JITModule {
let mut plt_val = [ let mut plt_val = [
0xff, 0x25, 0, 0, 0, 0, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b, 0xff, 0x25, 0, 0, 0, 0, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b,
]; ];
let what = got_ptr as isize - 4; let what = got_ptr.as_ptr() as isize - 4;
let at = plt_ptr as isize + 2; let at = plt_ptr as isize + 2;
plt_val[2..6].copy_from_slice(&i32::to_ne_bytes(i32::try_from(what - at).unwrap())); plt_val[2..6].copy_from_slice(&i32::to_ne_bytes(i32::try_from(what - at).unwrap()));
std::ptr::write(plt_ptr, plt_val); std::ptr::write(plt_ptr, plt_val);
@@ -289,32 +301,25 @@ impl JITModule {
/// ///
/// Panics if there's no entry in the table for the given function. /// Panics if there's no entry in the table for the given function.
pub fn read_got_entry(&self, func_id: FuncId) -> *const u8 { pub fn read_got_entry(&self, func_id: FuncId) -> *const u8 {
unsafe { *self.function_got_entries[func_id].unwrap().as_ptr() } let got_entry = self.function_got_entries[func_id].unwrap();
unsafe { got_entry.as_ref() }.load(Ordering::SeqCst)
} }
fn get_got_address(&self, name: &ir::ExternalName) -> *const u8 { fn get_got_address(&self, name: &ir::ExternalName) -> NonNull<AtomicPtr<u8>> {
match *name { match *name {
ir::ExternalName::User { .. } => { ir::ExternalName::User { .. } => {
if ModuleDeclarations::is_function(name) { if ModuleDeclarations::is_function(name) {
let func_id = FuncId::from_name(name); let func_id = FuncId::from_name(name);
self.function_got_entries[func_id] self.function_got_entries[func_id].unwrap()
.unwrap()
.as_ptr()
.cast::<u8>()
} else { } else {
let data_id = DataId::from_name(name); let data_id = DataId::from_name(name);
self.data_object_got_entries[data_id] self.data_object_got_entries[data_id].unwrap()
.unwrap()
.as_ptr()
.cast::<u8>()
} }
} }
ir::ExternalName::LibCall(ref libcall) => self ir::ExternalName::LibCall(ref libcall) => *self
.libcall_got_entries .libcall_got_entries
.get(libcall) .get(libcall)
.unwrap_or_else(|| panic!("can't resolve libcall {}", libcall)) .unwrap_or_else(|| panic!("can't resolve libcall {}", libcall)),
.as_ptr()
.cast::<u8>(),
_ => panic!("invalid ExternalName {}", name), _ => panic!("invalid ExternalName {}", name),
} }
} }
@@ -406,7 +411,7 @@ impl JITModule {
.expect("function must be compiled before it can be finalized"); .expect("function must be compiled before it can be finalized");
func.perform_relocations( func.perform_relocations(
|name| self.get_address(name), |name| self.get_address(name),
|name| self.get_got_address(name), |name| self.get_got_address(name).as_ptr().cast(),
|name| self.get_plt_address(name), |name| self.get_plt_address(name),
); );
} }
@@ -419,7 +424,7 @@ impl JITModule {
.expect("data object must be compiled before it can be finalized"); .expect("data object must be compiled before it can be finalized");
data.perform_relocations( data.perform_relocations(
|name| self.get_address(name), |name| self.get_address(name),
|name| self.get_got_address(name), |name| self.get_got_address(name).as_ptr().cast(),
|name| self.get_plt_address(name), |name| self.get_plt_address(name),
); );
} }
@@ -427,6 +432,10 @@ impl JITModule {
// Now that we're done patching, prepare the memory for execution! // Now that we're done patching, prepare the memory for execution!
self.memory.readonly.set_readonly(); self.memory.readonly.set_readonly();
self.memory.code.set_readable_and_executable(); self.memory.code.set_readable_and_executable();
for update in self.pending_got_updates.drain(..) {
unsafe { update.entry.as_ref() }.store(update.ptr as *mut _, Ordering::SeqCst);
}
} }
/// Create a new `JITModule`. /// Create a new `JITModule`.
@@ -438,33 +447,38 @@ impl JITModule {
); );
} }
let mut memory = MemoryHandle { let mut module = Self {
code: Memory::new(), isa: builder.isa,
readonly: Memory::new(), hotswap_enabled: builder.hotswap_enabled,
writable: Memory::new(), symbols: builder.symbols,
libcall_names: builder.libcall_names,
memory: MemoryHandle {
code: Memory::new(),
readonly: Memory::new(),
writable: Memory::new(),
},
declarations: ModuleDeclarations::default(),
function_got_entries: SecondaryMap::new(),
function_plt_entries: SecondaryMap::new(),
data_object_got_entries: SecondaryMap::new(),
libcall_got_entries: HashMap::new(),
libcall_plt_entries: HashMap::new(),
compiled_functions: SecondaryMap::new(),
compiled_data_objects: SecondaryMap::new(),
functions_to_finalize: Vec::new(),
data_objects_to_finalize: Vec::new(),
pending_got_updates: Vec::new(),
}; };
let mut libcall_got_entries = HashMap::new();
let mut libcall_plt_entries = HashMap::new();
// Pre-create a GOT and PLT entry for each libcall. // Pre-create a GOT and PLT entry for each libcall.
let all_libcalls = if builder.isa.flags().is_pic() { let all_libcalls = if module.isa.flags().is_pic() {
ir::LibCall::all_libcalls() ir::LibCall::all_libcalls()
} else { } else {
&[] // Not PIC, so no GOT and PLT entries necessary &[] // Not PIC, so no GOT and PLT entries necessary
}; };
for &libcall in all_libcalls { for &libcall in all_libcalls {
let got_entry = memory let sym = (module.libcall_names)(libcall);
.writable let addr = if let Some(addr) = module
.allocate(
std::mem::size_of::<*const u8>(),
std::mem::align_of::<*const u8>().try_into().unwrap(),
)
.unwrap()
.cast::<*const u8>();
libcall_got_entries.insert(libcall, NonNull::new(got_entry).unwrap());
let sym = (builder.libcall_names)(libcall);
let addr = if let Some(addr) = builder
.symbols .symbols
.get(&sym) .get(&sym)
.copied() .copied()
@@ -474,37 +488,13 @@ impl JITModule {
} else { } else {
continue; continue;
}; };
unsafe { let got_entry = module.new_got_entry(addr);
std::ptr::write(got_entry, addr); module.libcall_got_entries.insert(libcall, got_entry);
} let plt_entry = module.new_plt_entry(got_entry);
let plt_entry = memory module.libcall_plt_entries.insert(libcall, plt_entry);
.code
.allocate(std::mem::size_of::<[u8; 16]>(), EXECUTABLE_DATA_ALIGNMENT)
.unwrap()
.cast::<[u8; 16]>();
libcall_plt_entries.insert(libcall, NonNull::new(plt_entry).unwrap());
unsafe {
Self::write_plt_entry_bytes(plt_entry, got_entry);
}
} }
Self { module
isa: builder.isa,
hotswap_enabled: builder.hotswap_enabled,
symbols: builder.symbols,
libcall_names: builder.libcall_names,
memory,
declarations: ModuleDeclarations::default(),
function_got_entries: SecondaryMap::new(),
function_plt_entries: SecondaryMap::new(),
data_object_got_entries: SecondaryMap::new(),
libcall_got_entries,
libcall_plt_entries,
compiled_functions: SecondaryMap::new(),
compiled_data_objects: SecondaryMap::new(),
functions_to_finalize: Vec::new(),
data_objects_to_finalize: Vec::new(),
}
} }
/// Allow a single future `define_function` on a previously defined function. This allows for /// Allow a single future `define_function` on a previously defined function. This allows for
@@ -682,9 +672,10 @@ impl Module for JITModule {
}); });
if self.isa.flags().is_pic() { if self.isa.flags().is_pic() {
unsafe { self.pending_got_updates.push(GotUpdate {
std::ptr::write(self.function_got_entries[id].unwrap().as_ptr(), ptr); entry: self.function_got_entries[id].unwrap(),
} ptr,
})
} }
if self.hotswap_enabled { if self.hotswap_enabled {
@@ -704,7 +695,7 @@ impl Module for JITModule {
.cast::<u8>(), .cast::<u8>(),
_ => panic!("invalid ExternalName {}", name), _ => panic!("invalid ExternalName {}", name),
}, },
|name| self.get_got_address(name), |name| self.get_got_address(name).as_ptr().cast(),
|name| self.get_plt_address(name), |name| self.get_plt_address(name),
); );
} else { } else {
@@ -754,9 +745,10 @@ impl Module for JITModule {
}); });
if self.isa.flags().is_pic() { if self.isa.flags().is_pic() {
unsafe { self.pending_got_updates.push(GotUpdate {
std::ptr::write(self.function_got_entries[id].unwrap().as_ptr(), ptr); entry: self.function_got_entries[id].unwrap(),
} ptr,
})
} }
if self.hotswap_enabled { if self.hotswap_enabled {
@@ -765,7 +757,7 @@ impl Module for JITModule {
.unwrap() .unwrap()
.perform_relocations( .perform_relocations(
|name| unreachable!("non GOT or PLT relocation in function {} to {}", id, name), |name| unreachable!("non GOT or PLT relocation in function {} to {}", id, name),
|name| self.get_got_address(name), |name| self.get_got_address(name).as_ptr().cast(),
|name| self.get_plt_address(name), |name| self.get_plt_address(name),
); );
} else { } else {
@@ -836,9 +828,10 @@ impl Module for JITModule {
self.compiled_data_objects[id] = Some(CompiledBlob { ptr, size, relocs }); self.compiled_data_objects[id] = Some(CompiledBlob { ptr, size, relocs });
self.data_objects_to_finalize.push(id); self.data_objects_to_finalize.push(id);
if self.isa.flags().is_pic() { if self.isa.flags().is_pic() {
unsafe { self.pending_got_updates.push(GotUpdate {
std::ptr::write(self.data_object_got_entries[id].unwrap().as_ptr(), ptr); entry: self.data_object_got_entries[id].unwrap(),
} ptr,
})
} }
Ok(()) Ok(())