Consolidate methods of memory initialization (#3766)

* Consolidate methods of memory initialization

This commit consolidates the few locations we have that perform memory
initialization. Namely, the uffd logic for creating paged memory and the
memfd logic for creating a memory image now share an implementation to
avoid duplicating bounds checks and other validation conditions. The
main purpose of this commit is to fix a fuzz bug where a multiplication
overflowed. The overflow itself was benign, but it seemed better to fix
the overflow in only one place instead of multiple.
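
As a rough sketch, the consolidated routine has the shape below. This is
a simplified illustration with hypothetical stand-in types, not the
exact `wasmtime_environ` API (the real `InitMemory` also carries an
accessor for global values used as segment bases):

    use std::ops::Range;

    type MemoryIndex = usize; // stand-in for wasmtime's index type

    /// How memory sizes are resolved while walking initializers: from
    /// static module information at compile time, or from a live
    /// instance at runtime.
    enum InitMemory<'a> {
        CompileTime { minimum_size_in_bytes: &'a [u64] },
        Runtime { size_in_bytes: &'a dyn Fn(MemoryIndex) -> u64 },
    }

    struct Initializer {
        memory: MemoryIndex,
        offset: u64,
        data: Range<u32>, // range into the module's concatenated data
    }

    /// Walks every initializer, performs the bounds check exactly once
    /// in this one place, and hands each in-bounds chunk to `write`.
    /// Returns false if a segment is out of bounds or `write` declines.
    fn init_memory(
        inits: &[Initializer],
        state: InitMemory<'_>,
        write: &mut dyn FnMut(MemoryIndex, u64, &Range<u32>) -> bool,
    ) -> bool {
        for init in inits {
            let size = match &state {
                InitMemory::CompileTime { minimum_size_in_bytes } => {
                    minimum_size_in_bytes[init.memory]
                }
                InitMemory::Runtime { size_in_bytes } => size_in_bytes(init.memory),
            };
            let len = u64::from(init.data.end - init.data.start);
            let in_bounds = init.offset.checked_add(len).map_or(false, |end| end <= size);
            if !in_bounds || !write(init.memory, init.offset, &init.data) {
                return false;
            }
        }
        true
    }

The runtime instantiation path and the compile-time memfd image-building
path then differ essentially only in the `InitMemory` state and the
`write` callback they supply.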

The overflow in question happens when checking whether an initializer is
statically out-of-bounds: a memory's minimum size, in pages, is
multiplied by the wasm page size and the result is returned as a `u64`.
For memory64 memories with a minimum size of `1 << 48` pages this
multiplication overflows. This was actually a preexisting bug in the
`try_paged_init` function which was copied for memfd, but it cropped up
here because memfd is used more often than paged initialization. The fix
is to skip validation of the `end` index when the byte size of the
memory is `1 << 64`, since any `end` index representable as a `u64` is
then in-bounds. This is somewhat of an esoteric case, though, since a
memory with a minimum byte size of `1 << 64` can't ever exist (we can't
even ask the OS for that much memory, and even if we could, the request
would fail).
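
An illustrative reconstruction of the check and its fix (not the exact
wasmtime code; the names here are made up for the example):

    const WASM_PAGE_SIZE: u64 = 1 << 16; // 64 KiB

    /// Is `init_end` statically within a memory of `minimum_pages` pages?
    fn statically_in_bounds(minimum_pages: u64, init_end: u64) -> bool {
        // The old code computed `minimum_pages * WASM_PAGE_SIZE`
        // directly. For a memory64 minimum of `1 << 48` pages that
        // product is `1 << 64`, one past `u64::MAX`, so the
        // multiplication overflows.
        //
        // With `checked_mul`, overflow means the memory's byte size
        // can't be represented in a `u64`, in which case any `init_end`
        // that fits in a `u64` is necessarily in bounds and the
        // comparison can be skipped.
        match minimum_pages.checked_mul(WASM_PAGE_SIZE) {
            Some(byte_size) => init_end <= byte_size,
            None => true,
        }
    }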

* Fix memfd test

* Fix some tests

* Remove InitMemory enum

* Add an `is_segmented` helper method

* More clear variable name

* Make arguments to `init_memory` more descriptive
Author: Alex Crichton (committed by GitHub)
Date:   2022-02-04 13:17:25 -06:00
parent  a519e5ab64
commit  04d2caea7b

6 changed files with 428 additions and 332 deletions


@@ -18,8 +18,8 @@ use std::sync::Arc;
 use thiserror::Error;
 use wasmtime_environ::{
     DefinedFuncIndex, DefinedMemoryIndex, DefinedTableIndex, EntityRef, FunctionInfo, GlobalInit,
-    MemoryInitialization, MemoryInitializer, Module, ModuleType, PrimaryMap, SignatureIndex,
-    TableInitializer, TrapCode, WasmType, WASM_PAGE_SIZE,
+    InitMemory, MemoryInitialization, MemoryInitializer, Module, ModuleType, PrimaryMap,
+    SignatureIndex, TableInitializer, TrapCode, WasmType, WASM_PAGE_SIZE,
 };
 #[cfg(feature = "pooling-allocator")]
@@ -379,34 +379,60 @@ fn check_memory_init_bounds(
     Ok(())
 }
 
-fn initialize_memories(
-    instance: &mut Instance,
-    module: &Module,
-    initializers: &[MemoryInitializer],
-) -> Result<(), InstantiationError> {
-    for init in initializers {
-        // Check whether we can skip all initializers (due to, e.g.,
-        // memfd).
-        let memory = init.memory_index;
-        if let Some(defined_index) = module.defined_memory_index(memory) {
-            // We can only skip if there is actually a MemFD image. In
-            // some situations the MemFD image creation code will bail
-            // (e.g. due to an out of bounds data segment) and so we
-            // need to fall back on the usual initialization below.
-            if !instance.memories[defined_index].needs_init() {
-                continue;
-            }
-        }
-
-        instance
-            .memory_init_segment(
-                init.memory_index,
-                init.data.clone(),
-                get_memory_init_start(init, instance)?,
-                0,
-                init.data.end - init.data.start,
-            )
-            .map_err(InstantiationError::Trap)?;
-    }
+fn initialize_memories(instance: &mut Instance, module: &Module) -> Result<(), InstantiationError> {
+    let memory_size_in_pages =
+        &|memory| (instance.get_memory(memory).current_length as u64) / u64::from(WASM_PAGE_SIZE);
+
+    // Loads the `global` value and returns it as a `u64`, but sign-extends
+    // 32-bit globals which can be used as the base for 32-bit memories.
+    let get_global_as_u64 = &|global| unsafe {
+        let def = if let Some(def_index) = instance.module.defined_global_index(global) {
+            instance.global(def_index)
+        } else {
+            &*instance.imported_global(global).from
+        };
+        if module.globals[global].wasm_ty == WasmType::I64 {
+            *def.as_u64()
+        } else {
+            u64::from(*def.as_u32())
+        }
+    };
+
+    // Delegates to the `init_memory` method which is sort of a duplicate of
+    // `instance.memory_init_segment` but is used at compile-time in other
+    // contexts so is shared here to have only one method of memory
+    // initialization.
+    //
+    // This call to `init_memory` notably implements all the bells and whistles
+    // so errors only happen if an out-of-bounds segment is found, in which case
+    // a trap is returned.
+    let ok = module.memory_initialization.init_memory(
+        InitMemory::Runtime {
+            memory_size_in_pages,
+            get_global_as_u64,
+        },
+        &mut |memory_index, offset, data| {
+            // If this initializer applies to a defined memory but that memory
+            // doesn't need initialization, due to something like uffd or memfd
+            // pre-initializing it via mmap magic, then this initializer can be
+            // skipped entirely.
+            if let Some(memory_index) = module.defined_memory_index(memory_index) {
+                if !instance.memories[memory_index].needs_init() {
+                    return true;
+                }
+            }
+            let memory = instance.get_memory(memory_index);
+            let dst_slice =
+                unsafe { slice::from_raw_parts_mut(memory.base, memory.current_length) };
+            let dst = &mut dst_slice[usize::try_from(offset).unwrap()..][..data.len()];
+            dst.copy_from_slice(instance.wasm_data(data.clone()));
+            true
+        },
+    );
+    if !ok {
+        return Err(InstantiationError::Trap(Trap::wasm(
+            TrapCode::HeapOutOfBounds,
+        )));
+    }
+
     Ok(())
@@ -416,16 +442,11 @@ fn check_init_bounds(instance: &mut Instance, module: &Module) -> Result<(), Ins
     check_table_init_bounds(instance, module)?;
 
     match &instance.module.memory_initialization {
-        MemoryInitialization::Paged { out_of_bounds, .. } => {
-            if *out_of_bounds {
-                return Err(InstantiationError::Link(LinkError(
-                    "memory out of bounds: data segment does not fit".into(),
-                )));
-            }
-        }
         MemoryInitialization::Segmented(initializers) => {
             check_memory_init_bounds(instance, initializers)?;
         }
+        // Statically validated already to have everything in-bounds.
+        MemoryInitialization::Paged { .. } => {}
     }
 
     Ok(())
@@ -448,40 +469,7 @@ fn initialize_instance(
     initialize_tables(instance, module)?;
 
     // Initialize the memories
-    match &module.memory_initialization {
-        MemoryInitialization::Paged { map, out_of_bounds } => {
-            for (index, pages) in map {
-                // Check whether the memory actually needs
-                // initialization. It may not if we're using a CoW
-                // mechanism like memfd.
-                if !instance.memories[index].needs_init() {
-                    continue;
-                }
-                let memory = instance.memory(index);
-                let slice =
-                    unsafe { slice::from_raw_parts_mut(memory.base, memory.current_length) };
-
-                for (page_index, page) in pages {
-                    debug_assert_eq!(page.end - page.start, WASM_PAGE_SIZE);
-                    let start = (*page_index * u64::from(WASM_PAGE_SIZE)) as usize;
-                    let end = start + WASM_PAGE_SIZE as usize;
-                    slice[start..end].copy_from_slice(instance.wasm_data(page.clone()));
-                }
-            }
-
-            // Check for out of bound access after initializing the pages to maintain
-            // the expected behavior of the bulk memory spec.
-            if *out_of_bounds {
-                return Err(InstantiationError::Trap(Trap::wasm(
-                    TrapCode::HeapOutOfBounds,
-                )));
-            }
-        }
-        MemoryInitialization::Segmented(initializers) => {
-            initialize_memories(instance, module, initializers)?;
-        }
-    }
+    initialize_memories(instance, &module)?;
 
     Ok(())
 }


@@ -1069,7 +1069,7 @@ unsafe impl InstanceAllocator for PoolingInstanceAllocator {
         cfg_if::cfg_if! {
             if #[cfg(all(feature = "uffd", target_os = "linux"))] {
                 match &module.memory_initialization {
-                    wasmtime_environ::MemoryInitialization::Paged{ out_of_bounds, .. } => {
+                    wasmtime_environ::MemoryInitialization::Paged { .. } => {
                         if !is_bulk_memory {
                             super::check_init_bounds(instance, module)?;
                         }
@@ -1079,13 +1079,6 @@ unsafe impl InstanceAllocator for PoolingInstanceAllocator {
                         // Don't initialize the memory; the fault handler will back the pages when accessed
 
-                        // If there was an out of bounds access observed in initialization, return a trap
-                        if *out_of_bounds {
-                            return Err(InstantiationError::Trap(crate::traphandlers::Trap::wasm(
-                                wasmtime_environ::TrapCode::HeapOutOfBounds,
-                            )));
-                        }
-
                         Ok(())
                     },
                     _ => initialize_instance(instance, module, is_bulk_memory)


@@ -263,6 +263,7 @@ unsafe fn initialize_wasm_page(
 ) -> Result<()> {
     // Check for paged initialization and copy the page if present in the initialization data
     if let MemoryInitialization::Paged { map, .. } = &instance.module.memory_initialization {
+        let memory_index = instance.module().memory_index(memory_index);
         let pages = &map[memory_index];
         let pos = pages.binary_search_by_key(&(page_index as u64), |k| k.0);


@@ -6,12 +6,10 @@ use anyhow::Result;
 use libc::c_void;
 use memfd::{Memfd, MemfdOptions};
 use rustix::fd::AsRawFd;
-use rustix::fs::FileExt;
+use std::io::Write;
 use std::sync::Arc;
 use std::{convert::TryFrom, ops::Range};
-use wasmtime_environ::{
-    DefinedMemoryIndex, MemoryInitialization, MemoryInitializer, MemoryPlan, Module, PrimaryMap,
-};
+use wasmtime_environ::{DefinedMemoryIndex, InitMemory, Module, PrimaryMap};
 
 /// MemFDs containing backing images for certain memories in a module.
 ///
@@ -21,7 +19,7 @@ pub struct ModuleMemFds {
     memories: PrimaryMap<DefinedMemoryIndex, Option<Arc<MemoryMemFd>>>,
 }
 
-const MAX_MEMFD_IMAGE_SIZE: u64 = 1024 * 1024 * 1024; // limit to 1GiB.
+const MAX_MEMFD_IMAGE_SIZE: usize = 1024 * 1024 * 1024; // limit to 1GiB.
 
 impl ModuleMemFds {
     pub(crate) fn get_memory_image(
@@ -54,33 +52,6 @@ pub struct MemoryMemFd {
     pub offset: usize,
 }
 
-fn unsupported_initializer(segment: &MemoryInitializer, plan: &MemoryPlan) -> bool {
-    // If the segment has a base that is dynamically determined
-    // (by a global value, which may be a function of an imported
-    // module, for example), then we cannot build a single static
-    // image that is used for every instantiation. So we skip this
-    // memory entirely.
-    let end = match segment.end() {
-        None => {
-            return true;
-        }
-        Some(end) => end,
-    };
-
-    // Cannot be out-of-bounds. If there is a *possibility* it may
-    // be, then we just fall back on ordinary initialization.
-    if plan.initializer_possibly_out_of_bounds(segment) {
-        return true;
-    }
-
-    // Must fit in our max size.
-    if end > MAX_MEMFD_IMAGE_SIZE {
-        return true;
-    }
-
-    false
-}
-
 fn create_memfd() -> Result<Memfd> {
     // Create the memfd. It needs a name, but the
     // documentation for `memfd_create()` says that names can
@@ -97,124 +68,104 @@ impl ModuleMemFds {
     /// instantiation and execution by using memfd-backed memories.
     pub fn new(module: &Module, wasm_data: &[u8]) -> Result<Option<Arc<ModuleMemFds>>> {
         let page_size = region::page::size() as u64;
+        let page_align = |x: u64| x & !(page_size - 1);
+        let page_align_up = |x: u64| page_align(x + page_size - 1);
 
+        // First build up an in-memory image for each memory. This in-memory
+        // representation is discarded if the memory initializers aren't "of
+        // the right shape" where the desired shape is:
+        //
+        // * Only initializers for defined memories.
+        // * Only initializers with static offsets (no globals).
+        // * Only in-bound initializers.
+        //
+        // The `init_memory` method of `MemoryInitialization` is used here to
+        // do most of the validation for us, and otherwise the data chunks are
+        // collected into the `images` array here.
+        let mut images: PrimaryMap<DefinedMemoryIndex, Vec<u8>> = PrimaryMap::default();
         let num_defined_memories = module.memory_plans.len() - module.num_imported_memories;
-
-        // Allocate a memfd file initially for every memory. We'll
-        // release those and set `excluded_memories` for those that we
-        // determine during initializer processing we cannot support a
-        // static image (e.g. due to dynamically-located segments).
-        let mut memfds: PrimaryMap<DefinedMemoryIndex, Option<Memfd>> = PrimaryMap::default();
-        let mut sizes: PrimaryMap<DefinedMemoryIndex, u64> = PrimaryMap::default();
-        let mut excluded_memories: PrimaryMap<DefinedMemoryIndex, bool> = PrimaryMap::new();
-
         for _ in 0..num_defined_memories {
-            memfds.push(None);
-            sizes.push(0);
-            excluded_memories.push(false);
+            images.push(Vec::new());
         }
+        let ok = module.memory_initialization.init_memory(
+            InitMemory::CompileTime(module),
+            &mut |memory, offset, data_range| {
+                // Memfd-based initialization of an imported memory isn't
+                // implemented right now, although might perhaps be
+                // theoretically possible for statically-known-in-bounds
+                // segments with page-aligned portions.
+                let memory = match module.defined_memory_index(memory) {
+                    Some(index) => index,
+                    None => return false,
+                };
+
+                // Splat the `data_range` into the `image` for this memory,
+                // updating it as necessary with 0s for holes and such.
+                let image = &mut images[memory];
+                let data = &wasm_data[data_range.start as usize..data_range.end as usize];
+                let offset = offset as usize;
+                let new_image_len = offset + data.len();
+                if image.len() < new_image_len {
+                    if new_image_len > MAX_MEMFD_IMAGE_SIZE {
+                        return false;
+                    }
+                    image.resize(new_image_len, 0);
+                }
+                image[offset..][..data.len()].copy_from_slice(data);
+                true
+            },
+        );
 
-        let round_up_page = |len: u64| (len + page_size - 1) & !(page_size - 1);
-
-        match &module.memory_initialization {
-            &MemoryInitialization::Segmented(ref segments) => {
-                for (i, segment) in segments.iter().enumerate() {
-                    let defined_memory = match module.defined_memory_index(segment.memory_index) {
-                        Some(defined_memory) => defined_memory,
-                        None => continue,
-                    };
-                    if excluded_memories[defined_memory] {
-                        continue;
-                    }
-                    if unsupported_initializer(segment, &module.memory_plans[segment.memory_index])
-                    {
-                        memfds[defined_memory] = None;
-                        excluded_memories[defined_memory] = true;
-                        continue;
-                    }
-                    if memfds[defined_memory].is_none() {
-                        memfds[defined_memory] = Some(create_memfd()?);
-                    }
-                    let memfd = memfds[defined_memory].as_mut().unwrap();
-                    let end = round_up_page(segment.end().expect("must have statically-known end"));
-                    if end > sizes[defined_memory] {
-                        sizes[defined_memory] = end;
-                        memfd.as_file().set_len(end)?;
-                    }
-                    let base = segments[i].offset;
-                    let data = &wasm_data[segment.data.start as usize..segment.data.end as usize];
-                    memfd.as_file().write_at(data, base)?;
-                }
-            }
-            &MemoryInitialization::Paged { ref map, .. } => {
-                for (defined_memory, pages) in map {
-                    let top = pages
-                        .iter()
-                        .map(|(base, range)| *base + range.len() as u64)
-                        .max()
-                        .unwrap_or(0);
-                    let memfd = create_memfd()?;
-                    memfd.as_file().set_len(top)?;
-                    for (base, range) in pages {
-                        let data = &wasm_data[range.start as usize..range.end as usize];
-                        memfd.as_file().write_at(data, *base)?;
-                    }
-                    memfds[defined_memory] = Some(memfd);
-                    sizes[defined_memory] = top;
-                }
-            }
-        }
+        // If any initializer wasn't applicable then we skip memfds entirely.
+        if !ok {
+            return Ok(None);
+        }
 
-        // Now finalize each memory.
-        let mut memories: PrimaryMap<DefinedMemoryIndex, Option<Arc<MemoryMemFd>>> =
-            PrimaryMap::default();
-        for (defined_memory, maybe_memfd) in memfds {
-            let memfd = match maybe_memfd {
-                Some(memfd) => memfd,
+        // With an in-memory representation of all memory images a `memfd` is
+        // now created and the data is pushed into the memfd. Note that the
+        // memfd representation will trim leading and trailing pages of zeros
+        // to store as little data as possible in the memfd. This is not only a
+        // performance improvement in the sense of "copy less data to the
+        // kernel" but it's also more performant to fault in zeros from
+        // anonymous-backed pages instead of memfd-backed pages-of-zeros (as
+        // the kernel knows anonymous mappings are always zero and has a cache
+        // of zero'd pages).
+        let mut memories = PrimaryMap::default();
+        for (defined_memory, image) in images {
+            // Find the first nonzero byte, and if all the bytes are zero then
+            // we can skip the memfd for this memory since there's no
+            // meaningful initialization.
+            let nonzero_start = match image.iter().position(|b| *b != 0) {
+                Some(i) => i as u64,
                 None => {
                     memories.push(None);
                     continue;
                 }
             };
-            let size = sizes[defined_memory];
 
-            // Find leading and trailing zero data so that the mmap
-            // can precisely map only the nonzero data; anon-mmap zero
-            // memory is faster for anything that doesn't actually
-            // have content.
-            let mut page_data = vec![0; page_size as usize];
-            let mut page_is_nonzero = |page| {
-                let offset = page_size * page;
-                memfd.as_file().read_at(&mut page_data[..], offset).unwrap();
-                page_data.iter().any(|byte| *byte != 0)
-            };
-            let n_pages = size / page_size;
+            // Find the last nonzero byte, which must exist at this point since
+            // we found one going forward. Add one to find the index of the
+            // last zero, which may also be the length of the image.
+            let nonzero_end = image.iter().rposition(|b| *b != 0).unwrap() as u64 + 1;
 
-            let mut offset = 0;
-            for page in 0..n_pages {
-                if page_is_nonzero(page) {
-                    break;
-                }
-                offset += page_size;
-            }
-            let len = if offset == size {
-                0
-            } else {
-                let mut len = 0;
-                for page in (0..n_pages).rev() {
-                    if page_is_nonzero(page) {
-                        len = (page + 1) * page_size - offset;
-                        break;
-                    }
-                }
-                len
-            };
+            // The offset of this image must be OS-page-aligned since we'll be
+            // starting the mmap at an aligned address. Align down the start
+            // index to the first index that's page aligned.
+            let offset = page_align(nonzero_start);
+
+            // The length of the image must also be page aligned and may reach
+            // beyond the end of the `image` array we have already. Take the
+            // length of the nonzero portion and then align it up to the page size.
+            let len = page_align_up(nonzero_end - offset);
+
+            // Write the nonzero data to the memfd and then use `set_len` to
+            // ensure that the length of the memfd is page-aligned where the gap
+            // at the end, if any, is filled with zeros.
+            let memfd = create_memfd()?;
+            memfd
+                .as_file()
+                .write_all(&image[offset as usize..nonzero_end as usize])?;
+            memfd.as_file().set_len(len)?;
 
             // Seal the memfd's data and length.
             //
@@ -239,11 +190,12 @@ impl ModuleMemFds {
             assert_eq!(offset % page_size, 0);
             assert_eq!(len % page_size, 0);
-            memories.push(Some(Arc::new(MemoryMemFd {
+            let idx = memories.push(Some(Arc::new(MemoryMemFd {
                 fd: memfd,
                 offset: usize::try_from(offset).unwrap(),
                 len: usize::try_from(len).unwrap(),
             })));
+            assert_eq!(idx, defined_memory);
         }
 
         Ok(Some(Arc::new(ModuleMemFds { memories })))
@@ -457,7 +409,7 @@ impl MemFdSlot {
                 rustix::io::ProtFlags::READ | rustix::io::ProtFlags::WRITE,
                 rustix::io::MapFlags::PRIVATE | rustix::io::MapFlags::FIXED,
                 image.fd.as_file(),
-                image.offset as u64,
+                0,
             )
             .map_err(|e| InstantiationError::Resource(e.into()))?;
             assert_eq!(ptr as usize, self.base + image.offset);
@@ -580,17 +532,19 @@ mod test {
     use super::MemoryMemFd;
     use crate::mmap::Mmap;
    use anyhow::Result;
-    use rustix::fs::FileExt;
+    use std::io::Write;
 
     fn create_memfd_with_data(offset: usize, data: &[u8]) -> Result<MemoryMemFd> {
-        // Offset must be page-aligned.
         let page_size = region::page::size();
-        let memfd = create_memfd()?;
+        // Offset and length have to be page-aligned.
         assert_eq!(offset & (page_size - 1), 0);
-        let image_len = offset + data.len();
-        let image_len = (image_len + page_size - 1) & !(page_size - 1);
+        let memfd = create_memfd()?;
+        memfd.as_file().write_all(data)?;
+
+        // The image length is rounded up to the nearest page size
+        let image_len = (data.len() + page_size - 1) & !(page_size - 1);
         memfd.as_file().set_len(image_len as u64)?;
-        memfd.as_file().write_at(data, offset as u64)?;
 
         Ok(MemoryMemFd {
             fd: memfd,
            len: image_len,