Merge pull request #2287 from bjorn3/simplejit_improvements

Some SimpleJIT improvements
This commit is contained in:
Pat Hickey
2020-10-29 12:09:37 -07:00
committed by GitHub
2 changed files with 110 additions and 89 deletions

View File

@@ -122,6 +122,7 @@ impl From<FuncOrDataId> for ir::ExternalName {
} }
/// Information about a function which can be called. /// Information about a function which can be called.
#[derive(Debug)]
pub struct FunctionDeclaration { pub struct FunctionDeclaration {
pub name: String, pub name: String,
pub linkage: Linkage, pub linkage: Linkage,
@@ -176,6 +177,7 @@ pub enum ModuleError {
pub type ModuleResult<T> = Result<T, ModuleError>; pub type ModuleResult<T> = Result<T, ModuleError>;
/// Information about a data object which can be accessed. /// Information about a data object which can be accessed.
#[derive(Debug)]
pub struct DataDeclaration { pub struct DataDeclaration {
pub name: String, pub name: String,
pub linkage: Linkage, pub linkage: Linkage,
@@ -196,7 +198,7 @@ impl DataDeclaration {
/// This provides a view to the state of a module which allows `ir::ExternalName`s to be translated /// This provides a view to the state of a module which allows `ir::ExternalName`s to be translated
/// into `FunctionDeclaration`s and `DataDeclaration`s. /// into `FunctionDeclaration`s and `DataDeclaration`s.
#[derive(Default)] #[derive(Debug, Default)]
pub struct ModuleDeclarations { pub struct ModuleDeclarations {
names: HashMap<String, FuncOrDataId>, names: HashMap<String, FuncOrDataId>,
functions: PrimaryMap<FuncId, FunctionDeclaration>, functions: PrimaryMap<FuncId, FunctionDeclaration>,

View File

@@ -126,10 +126,10 @@ pub struct SimpleJITModule {
isa: Box<dyn TargetIsa>, isa: Box<dyn TargetIsa>,
symbols: HashMap<String, *const u8>, symbols: HashMap<String, *const u8>,
libcall_names: Box<dyn Fn(ir::LibCall) -> String>, libcall_names: Box<dyn Fn(ir::LibCall) -> String>,
memory: SimpleJITMemoryHandle, memory: MemoryHandle,
declarations: ModuleDeclarations, declarations: ModuleDeclarations,
functions: SecondaryMap<FuncId, Option<SimpleJITCompiledFunction>>, functions: SecondaryMap<FuncId, Option<CompiledBlob>>,
data_objects: SecondaryMap<DataId, Option<SimpleJITCompiledData>>, data_objects: SecondaryMap<DataId, Option<CompiledBlob>>,
functions_to_finalize: Vec<FuncId>, functions_to_finalize: Vec<FuncId>,
data_objects_to_finalize: Vec<DataId>, data_objects_to_finalize: Vec<DataId>,
} }
@@ -151,21 +151,14 @@ struct StackMapRecord {
} }
#[derive(Clone)] #[derive(Clone)]
pub struct SimpleJITCompiledFunction { struct CompiledBlob {
code: *mut u8, ptr: *mut u8,
size: usize,
relocs: Vec<RelocRecord>,
}
#[derive(Clone)]
pub struct SimpleJITCompiledData {
storage: *mut u8,
size: usize, size: usize,
relocs: Vec<RelocRecord>, relocs: Vec<RelocRecord>,
} }
/// A handle to allow freeing memory allocated by the `Module`. /// A handle to allow freeing memory allocated by the `Module`.
struct SimpleJITMemoryHandle { struct MemoryHandle {
code: Memory, code: Memory,
readonly: Memory, readonly: Memory,
writable: Memory, writable: Memory,
@@ -174,10 +167,10 @@ struct SimpleJITMemoryHandle {
/// A `SimpleJITProduct` allows looking up the addresses of all functions and data objects /// A `SimpleJITProduct` allows looking up the addresses of all functions and data objects
/// defined in the original module. /// defined in the original module.
pub struct SimpleJITProduct { pub struct SimpleJITProduct {
memory: SimpleJITMemoryHandle, memory: MemoryHandle,
declarations: ModuleDeclarations, declarations: ModuleDeclarations,
functions: SecondaryMap<FuncId, Option<SimpleJITCompiledFunction>>, functions: SecondaryMap<FuncId, Option<CompiledBlob>>,
data_objects: SecondaryMap<DataId, Option<SimpleJITCompiledData>>, data_objects: SecondaryMap<DataId, Option<CompiledBlob>>,
} }
impl SimpleJITProduct { impl SimpleJITProduct {
@@ -205,7 +198,7 @@ impl SimpleJITProduct {
self.functions[func_id] self.functions[func_id]
.as_ref() .as_ref()
.unwrap_or_else(|| panic!("{} is not defined", func_id)) .unwrap_or_else(|| panic!("{} is not defined", func_id))
.code .ptr
} }
/// Return the address and size of a data object. /// Return the address and size of a data object.
@@ -213,45 +206,83 @@ impl SimpleJITProduct {
let data = self.data_objects[data_id] let data = self.data_objects[data_id]
.as_ref() .as_ref()
.unwrap_or_else(|| panic!("{} is not defined", data_id)); .unwrap_or_else(|| panic!("{} is not defined", data_id));
(data.storage, data.size) (data.ptr, data.size)
} }
} }
impl SimpleJITModule { impl SimpleJITModule {
fn lookup_symbol(&self, name: &str) -> *const u8 { fn lookup_symbol(&self, name: &str) -> Option<*const u8> {
match self.symbols.get(name) { self.symbols
Some(&ptr) => ptr, .get(name)
None => lookup_with_dlsym(name), .copied()
} .or_else(|| lookup_with_dlsym(name))
} }
fn get_definition(&self, name: &ir::ExternalName) -> *const u8 { fn get_definition(&self, name: &ir::ExternalName) -> *const u8 {
match *name { match *name {
ir::ExternalName::User { .. } => { ir::ExternalName::User { .. } => {
if self.declarations.is_function(name) { let (name, linkage) = if self.declarations.is_function(name) {
let func_id = self.declarations.get_function_id(name); let func_id = self.declarations.get_function_id(name);
match &self.functions[func_id] { match &self.functions[func_id] {
Some(compiled) => compiled.code, Some(compiled) => return compiled.ptr,
None => { None => {
self.lookup_symbol(&self.declarations.get_function_decl(func_id).name) let decl = self.declarations.get_function_decl(func_id);
(&decl.name, decl.linkage)
} }
} }
} else { } else {
let data_id = self.declarations.get_data_id(name); let data_id = self.declarations.get_data_id(name);
match &self.data_objects[data_id] { match &self.data_objects[data_id] {
Some(compiled) => compiled.storage, Some(compiled) => return compiled.ptr,
None => self.lookup_symbol(&self.declarations.get_data_decl(data_id).name), None => {
let decl = self.declarations.get_data_decl(data_id);
(&decl.name, decl.linkage)
} }
} }
};
if let Some(ptr) = self.lookup_symbol(&name) {
ptr
} else if linkage == Linkage::Preemptible {
0 as *const u8
} else {
panic!("can't resolve symbol {}", name);
}
} }
ir::ExternalName::LibCall(ref libcall) => { ir::ExternalName::LibCall(ref libcall) => {
let sym = (self.libcall_names)(*libcall); let sym = (self.libcall_names)(*libcall);
self.lookup_symbol(&sym) self.lookup_symbol(&sym)
.unwrap_or_else(|| panic!("can't resolve libcall {}", sym))
} }
_ => panic!("invalid ExternalName {}", name), _ => panic!("invalid ExternalName {}", name),
} }
} }
/// Look up the address of the compiled blob stored for `func_id`.
///
/// Debug builds additionally assert that `func_id` is no longer queued in
/// `functions_to_finalize`.
///
/// # Panics
///
/// Panics if no compiled blob is recorded for `func_id`.
pub fn get_finalized_function(&self, func_id: FuncId) -> *const u8 {
    let slot = &self.functions[func_id];
    debug_assert!(
        !self.functions_to_finalize.contains(&func_id),
        "function not yet finalized"
    );
    let blob = slot
        .as_ref()
        .expect("function must be compiled before it can be finalized");
    blob.ptr
}
/// Look up the `(address, size)` pair of the compiled blob stored for `data_id`.
///
/// Debug builds additionally assert that `data_id` is no longer queued in
/// `data_objects_to_finalize`.
///
/// # Panics
///
/// Panics if no compiled blob is recorded for `data_id`.
pub fn get_finalized_data(&self, data_id: DataId) -> (*const u8, usize) {
    let slot = &self.data_objects[data_id];
    debug_assert!(
        !self.data_objects_to_finalize.contains(&data_id),
        "data object not yet finalized"
    );
    let blob = slot
        .as_ref()
        .expect("data object must be compiled before it can be finalized");
    (blob.ptr, blob.size)
}
fn record_function_for_perf(&self, ptr: *mut u8, size: usize, name: &str) { fn record_function_for_perf(&self, ptr: *mut u8, size: usize, name: &str) {
// The Linux perf tool supports JIT code via a /tmp/perf-$PID.map file, // The Linux perf tool supports JIT code via a /tmp/perf-$PID.map file,
// which contains memory regions and their associated names. If we // which contains memory regions and their associated names. If we
@@ -283,9 +314,8 @@ impl SimpleJITModule {
addend, addend,
} in &func.relocs } in &func.relocs
{ {
let ptr = func.code;
debug_assert!((offset as usize) < func.size); debug_assert!((offset as usize) < func.size);
let at = unsafe { ptr.offset(offset as isize) }; let at = unsafe { func.ptr.offset(offset as isize) };
let base = self.get_definition(name); let base = self.get_definition(name);
// TODO: Handle overflow. // TODO: Handle overflow.
let what = unsafe { base.offset(addend as isize) }; let what = unsafe { base.offset(addend as isize) };
@@ -331,9 +361,8 @@ impl SimpleJITModule {
addend, addend,
} in &data.relocs } in &data.relocs
{ {
let ptr = data.storage;
debug_assert!((offset as usize) < data.size); debug_assert!((offset as usize) < data.size);
let at = unsafe { ptr.offset(offset as isize) }; let at = unsafe { data.ptr.offset(offset as isize) };
let base = self.get_definition(name); let base = self.get_definition(name);
// TODO: Handle overflow. // TODO: Handle overflow.
let what = unsafe { base.offset(addend as isize) }; let what = unsafe { base.offset(addend as isize) };
@@ -360,9 +389,32 @@ impl SimpleJITModule {
} }
} }
/// Finalize every function and data object that has been defined but is still
/// waiting in the finalization queues.
///
/// Any symbol referenced from those bodies that is declared as needing a
/// definition has to be defined before this is called.
///
/// Afterwards, `get_finalized_function` and `get_finalized_data` give access to
/// the final artifacts.
pub fn finalize_definitions(&mut self) {
    // Drain the function queue first; it is left empty for later definitions.
    let pending_functions = std::mem::take(&mut self.functions_to_finalize);
    for func_id in pending_functions {
        let declaration = self.declarations.get_function_decl(func_id);
        debug_assert!(declaration.linkage.is_definable());
        self.finalize_function(func_id);
    }

    // The data queue is only drained once all queued functions are done.
    let pending_data = std::mem::take(&mut self.data_objects_to_finalize);
    for data_id in pending_data {
        let declaration = self.declarations.get_data_decl(data_id);
        debug_assert!(declaration.linkage.is_definable());
        self.finalize_data(data_id);
    }

    // Patching is complete, so the memory can now be prepared for execution.
    self.memory.readonly.set_readonly();
    self.memory.code.set_readable_and_executable();
}
/// Create a new `SimpleJITModule`. /// Create a new `SimpleJITModule`.
pub fn new(builder: SimpleJITBuilder) -> Self { pub fn new(builder: SimpleJITBuilder) -> Self {
let memory = SimpleJITMemoryHandle { let memory = MemoryHandle {
code: Memory::new(), code: Memory::new(),
readonly: Memory::new(), readonly: Memory::new(),
writable: Memory::new(), writable: Memory::new(),
@@ -451,8 +503,8 @@ impl<'simple_jit_backend> Module for SimpleJITModule {
self.record_function_for_perf(ptr, size, &decl.name); self.record_function_for_perf(ptr, size, &decl.name);
let mut reloc_sink = SimpleJITRelocSink::new(); let mut reloc_sink = SimpleJITRelocSink::default();
let mut stack_map_sink = SimpleJITStackMapSink::new(); let mut stack_map_sink = SimpleJITStackMapSink::default();
unsafe { unsafe {
ctx.emit_to_memory( ctx.emit_to_memory(
&*self.isa, &*self.isa,
@@ -463,8 +515,8 @@ impl<'simple_jit_backend> Module for SimpleJITModule {
) )
}; };
self.functions[id] = Some(SimpleJITCompiledFunction { self.functions[id] = Some(CompiledBlob {
code: ptr, ptr,
size, size,
relocs: reloc_sink.relocs, relocs: reloc_sink.relocs,
}); });
@@ -505,8 +557,8 @@ impl<'simple_jit_backend> Module for SimpleJITModule {
ptr::copy_nonoverlapping(bytes.as_ptr(), ptr, size); ptr::copy_nonoverlapping(bytes.as_ptr(), ptr, size);
} }
self.functions[id] = Some(SimpleJITCompiledFunction { self.functions[id] = Some(CompiledBlob {
code: ptr, ptr,
size, size,
relocs: vec![], relocs: vec![],
}); });
@@ -539,7 +591,7 @@ impl<'simple_jit_backend> Module for SimpleJITModule {
} = data.description(); } = data.description();
let size = init.size(); let size = init.size();
let storage = if decl.writable { let ptr = if decl.writable {
self.memory self.memory
.writable .writable
.allocate(size, align.unwrap_or(WRITABLE_DATA_ALIGNMENT)) .allocate(size, align.unwrap_or(WRITABLE_DATA_ALIGNMENT))
@@ -556,11 +608,11 @@ impl<'simple_jit_backend> Module for SimpleJITModule {
panic!("data is not initialized yet"); panic!("data is not initialized yet");
} }
Init::Zeros { .. } => { Init::Zeros { .. } => {
unsafe { ptr::write_bytes(storage, 0, size) }; unsafe { ptr::write_bytes(ptr, 0, size) };
} }
Init::Bytes { ref contents } => { Init::Bytes { ref contents } => {
let src = contents.as_ptr(); let src = contents.as_ptr();
unsafe { ptr::copy_nonoverlapping(src, storage, size) }; unsafe { ptr::copy_nonoverlapping(src, ptr, size) };
} }
} }
@@ -587,11 +639,7 @@ impl<'simple_jit_backend> Module for SimpleJITModule {
}); });
} }
self.data_objects[id] = Some(SimpleJITCompiledData { self.data_objects[id] = Some(CompiledBlob { ptr, size, relocs });
storage,
size,
relocs,
});
Ok(()) Ok(())
} }
@@ -606,20 +654,7 @@ impl SimpleJITModule {
/// This method does not need to be called when access to the memory /// This method does not need to be called when access to the memory
/// handle is not required. /// handle is not required.
pub fn finish(mut self) -> SimpleJITProduct { pub fn finish(mut self) -> SimpleJITProduct {
for func in std::mem::take(&mut self.functions_to_finalize) { self.finalize_definitions();
let decl = self.declarations.get_function_decl(func);
debug_assert!(decl.linkage.is_definable());
self.finalize_function(func);
}
for data in std::mem::take(&mut self.data_objects_to_finalize) {
let decl = self.declarations.get_data_decl(data);
debug_assert!(decl.linkage.is_definable());
self.finalize_data(data);
}
// Now that we're done patching, prepare the memory for execution!
self.memory.readonly.set_readonly();
self.memory.code.set_readable_and_executable();
SimpleJITProduct { SimpleJITProduct {
memory: self.memory, memory: self.memory,
@@ -631,18 +666,19 @@ impl SimpleJITModule {
} }
#[cfg(not(windows))] #[cfg(not(windows))]
fn lookup_with_dlsym(name: &str) -> *const u8 { fn lookup_with_dlsym(name: &str) -> Option<*const u8> {
let c_str = CString::new(name).unwrap(); let c_str = CString::new(name).unwrap();
let c_str_ptr = c_str.as_ptr(); let c_str_ptr = c_str.as_ptr();
let sym = unsafe { libc::dlsym(libc::RTLD_DEFAULT, c_str_ptr) }; let sym = unsafe { libc::dlsym(libc::RTLD_DEFAULT, c_str_ptr) };
if sym.is_null() { if sym.is_null() {
panic!("can't resolve symbol {}", name); None
} else {
Some(sym as *const u8)
} }
sym as *const u8
} }
#[cfg(windows)] #[cfg(windows)]
fn lookup_with_dlsym(name: &str) -> *const u8 { fn lookup_with_dlsym(name: &str) -> Option<*const u8> {
const MSVCRT_DLL: &[u8] = b"msvcrt.dll\0"; const MSVCRT_DLL: &[u8] = b"msvcrt.dll\0";
let c_str = CString::new(name).unwrap(); let c_str = CString::new(name).unwrap();
@@ -661,26 +697,16 @@ fn lookup_with_dlsym(name: &str) -> *const u8 {
if addr.is_null() { if addr.is_null() {
continue; continue;
} }
return addr as *const u8; return Some(addr as *const u8);
} }
let msg = if handles[1].is_null() { None
"(msvcrt not loaded)"
} else {
""
};
panic!("cannot resolve address of symbol {} {}", name, msg);
} }
} }
#[derive(Default)]
struct SimpleJITRelocSink { struct SimpleJITRelocSink {
pub relocs: Vec<RelocRecord>, relocs: Vec<RelocRecord>,
}
impl SimpleJITRelocSink {
pub fn new() -> Self {
Self { relocs: Vec::new() }
}
} }
impl RelocSink for SimpleJITRelocSink { impl RelocSink for SimpleJITRelocSink {
@@ -729,16 +755,9 @@ impl RelocSink for SimpleJITRelocSink {
} }
} }
#[derive(Default)]
struct SimpleJITStackMapSink { struct SimpleJITStackMapSink {
pub stack_maps: Vec<StackMapRecord>, stack_maps: Vec<StackMapRecord>,
}
impl SimpleJITStackMapSink {
pub fn new() -> Self {
Self {
stack_maps: Vec::new(),
}
}
} }
impl StackMapSink for SimpleJITStackMapSink { impl StackMapSink for SimpleJITStackMapSink {