Check safety of as_raw with a simplified borrow checker (#37)
* wiggle-runtime: add as_raw method for [T] * add trivial borrow checker back in * integrate runtime borrow checker with as_raw methods * handle pointer arith overflow correctly in as_raw, create PtrOverflow error * runtime: add validation back to GuestType * generate: impl validate for enums, flags, handles, ints * oops! make validate its own method on trait GuestTypeTransparent * fix transparent impls for enum, flag, handle, int * some structs are transparent. fix tests. * tests: define byte_slice_strat and friends * wiggle-tests: i believe my allocator is working now * some type juggling around memset for ease of use * make GuestTypeTransparent an unsafe trait * delete redundant validation of pointer align * fix doc * wiggle_test: aha, you cant use sets to track memory areas * add multi-string test which exercises the runtime borrow checker against HostMemory::byte_slice_strat * oops left debug panic in * remove redundant (& incorrect, since unchecked) length calc * redesign validate again, and actually hook to as_raw * makr all validate impls as inline this should hopefully allow as_raw's check loop to be unrolled to a no-op in most cases! * code review fixes
This commit is contained in:
@@ -2,16 +2,34 @@ use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
|
||||
pub trait LifetimeExt {
|
||||
fn is_transparent(&self) -> bool;
|
||||
fn needs_lifetime(&self) -> bool;
|
||||
}
|
||||
|
||||
impl LifetimeExt for witx::TypeRef {
|
||||
fn is_transparent(&self) -> bool {
|
||||
self.type_().is_transparent()
|
||||
}
|
||||
fn needs_lifetime(&self) -> bool {
|
||||
self.type_().needs_lifetime()
|
||||
}
|
||||
}
|
||||
|
||||
impl LifetimeExt for witx::Type {
|
||||
fn is_transparent(&self) -> bool {
|
||||
match self {
|
||||
witx::Type::Builtin(b) => b.is_transparent(),
|
||||
witx::Type::Struct(s) => s.is_transparent(),
|
||||
witx::Type::Enum { .. }
|
||||
| witx::Type::Flags { .. }
|
||||
| witx::Type::Int { .. }
|
||||
| witx::Type::Handle { .. } => true,
|
||||
witx::Type::Union { .. }
|
||||
| witx::Type::Pointer { .. }
|
||||
| witx::Type::ConstPointer { .. }
|
||||
| witx::Type::Array { .. } => false,
|
||||
}
|
||||
}
|
||||
fn needs_lifetime(&self) -> bool {
|
||||
match self {
|
||||
witx::Type::Builtin(b) => b.needs_lifetime(),
|
||||
@@ -29,6 +47,9 @@ impl LifetimeExt for witx::Type {
|
||||
}
|
||||
|
||||
impl LifetimeExt for witx::BuiltinType {
|
||||
fn is_transparent(&self) -> bool {
|
||||
!self.needs_lifetime()
|
||||
}
|
||||
fn needs_lifetime(&self) -> bool {
|
||||
match self {
|
||||
witx::BuiltinType::String => true,
|
||||
@@ -38,12 +59,18 @@ impl LifetimeExt for witx::BuiltinType {
|
||||
}
|
||||
|
||||
impl LifetimeExt for witx::StructDatatype {
|
||||
fn is_transparent(&self) -> bool {
|
||||
self.members.iter().all(|m| m.tref.is_transparent())
|
||||
}
|
||||
fn needs_lifetime(&self) -> bool {
|
||||
self.members.iter().any(|m| m.tref.needs_lifetime())
|
||||
}
|
||||
}
|
||||
|
||||
impl LifetimeExt for witx::UnionDatatype {
|
||||
fn is_transparent(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn needs_lifetime(&self) -> bool {
|
||||
self.variants
|
||||
.iter()
|
||||
|
||||
@@ -87,8 +87,9 @@ pub(super) fn define_enum(names: &Names, name: &witx::Id, e: &witx::EnumDatatype
|
||||
|
||||
fn read(location: &wiggle_runtime::GuestPtr<#ident>) -> Result<#ident, wiggle_runtime::GuestError> {
|
||||
use std::convert::TryFrom;
|
||||
let val = #repr::read(&location.cast())?;
|
||||
#ident::try_from(val)
|
||||
let reprval = #repr::read(&location.cast())?;
|
||||
let value = #ident::try_from(reprval)?;
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
fn write(location: &wiggle_runtime::GuestPtr<'_, #ident>, val: Self)
|
||||
@@ -97,5 +98,16 @@ pub(super) fn define_enum(names: &Names, name: &witx::Id, e: &witx::EnumDatatype
|
||||
#repr::write(&location.cast(), #repr::from(val))
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl <'a> wiggle_runtime::GuestTypeTransparent<'a> for #ident {
|
||||
#[inline]
|
||||
fn validate(location: *mut #ident) -> Result<(), wiggle_runtime::GuestError> {
|
||||
use std::convert::TryFrom;
|
||||
// Validate value in memory using #ident::try_from(reprval)
|
||||
let reprval = unsafe { (location as *mut #repr).read() };
|
||||
let _val = #ident::try_from(reprval)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,10 +134,11 @@ pub(super) fn define_flags(names: &Names, name: &witx::Id, f: &witx::FlagsDataty
|
||||
#repr::guest_align()
|
||||
}
|
||||
|
||||
fn read(location: &wiggle_runtime::GuestPtr<'a, #ident>) -> Result<#ident, wiggle_runtime::GuestError> {
|
||||
fn read(location: &wiggle_runtime::GuestPtr<#ident>) -> Result<#ident, wiggle_runtime::GuestError> {
|
||||
use std::convert::TryFrom;
|
||||
let bits = #repr::read(&location.cast())?;
|
||||
#ident::try_from(bits)
|
||||
let reprval = #repr::read(&location.cast())?;
|
||||
let value = #ident::try_from(reprval)?;
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
fn write(location: &wiggle_runtime::GuestPtr<'_, #ident>, val: Self) -> Result<(), wiggle_runtime::GuestError> {
|
||||
@@ -145,5 +146,16 @@ pub(super) fn define_flags(names: &Names, name: &witx::Id, f: &witx::FlagsDataty
|
||||
#repr::write(&location.cast(), val)
|
||||
}
|
||||
}
|
||||
unsafe impl <'a> wiggle_runtime::GuestTypeTransparent<'a> for #ident {
|
||||
#[inline]
|
||||
fn validate(location: *mut #ident) -> Result<(), wiggle_runtime::GuestError> {
|
||||
use std::convert::TryFrom;
|
||||
// Validate value in memory using #ident::try_from(reprval)
|
||||
let reprval = unsafe { (location as *mut #repr).read() };
|
||||
let _val = #ident::try_from(reprval)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ pub(super) fn define_handle(
|
||||
let size = h.mem_size_align().size as u32;
|
||||
let align = h.mem_size_align().align as usize;
|
||||
quote! {
|
||||
#[repr(transparent)]
|
||||
#[derive(Copy, Clone, Debug, ::std::hash::Hash, Eq, PartialEq)]
|
||||
pub struct #ident(u32);
|
||||
|
||||
@@ -62,5 +63,15 @@ pub(super) fn define_handle(
|
||||
u32::write(&location.cast(), val.0)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<'a> wiggle_runtime::GuestTypeTransparent<'a> for #ident {
|
||||
#[inline]
|
||||
fn validate(_location: *mut #ident) -> Result<(), wiggle_runtime::GuestError> {
|
||||
// All bit patterns accepted
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,11 +73,21 @@ pub(super) fn define_int(names: &Names, name: &witx::Id, i: &witx::IntDatatype)
|
||||
|
||||
fn read(location: &wiggle_runtime::GuestPtr<'a, #ident>) -> Result<#ident, wiggle_runtime::GuestError> {
|
||||
Ok(#ident(#repr::read(&location.cast())?))
|
||||
|
||||
}
|
||||
|
||||
fn write(location: &wiggle_runtime::GuestPtr<'_, #ident>, val: Self) -> Result<(), wiggle_runtime::GuestError> {
|
||||
#repr::write(&location.cast(), val.0)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<'a> wiggle_runtime::GuestTypeTransparent<'a> for #ident {
|
||||
#[inline]
|
||||
fn validate(_location: *mut #ident) -> Result<(), wiggle_runtime::GuestError> {
|
||||
// All bit patterns accepted
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,6 +77,32 @@ pub(super) fn define_struct(
|
||||
(quote!(), quote!(, Copy, PartialEq))
|
||||
};
|
||||
|
||||
let transparent = if s.is_transparent() {
|
||||
let member_validate = s.member_layout().into_iter().map(|ml| {
|
||||
let offset = ml.offset;
|
||||
let typename = names.type_ref(&ml.member.tref, anon_lifetime());
|
||||
quote! {
|
||||
// SAFETY: caller has validated bounds and alignment of `location`.
|
||||
// member_layout gives correctly-aligned pointers inside that area.
|
||||
#typename::validate(
|
||||
unsafe { (location as *mut u8).add(#offset) as *mut _ }
|
||||
)?;
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
unsafe impl<'a> wiggle_runtime::GuestTypeTransparent<'a> for #ident {
|
||||
#[inline]
|
||||
fn validate(location: *mut #ident) -> Result<(), wiggle_runtime::GuestError> {
|
||||
#(#member_validate)*
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote!()
|
||||
};
|
||||
|
||||
quote! {
|
||||
#[derive(Clone, Debug #extra_derive)]
|
||||
pub struct #ident #struct_lifetime {
|
||||
@@ -102,5 +128,7 @@ pub(super) fn define_struct(
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#transparent
|
||||
}
|
||||
}
|
||||
|
||||
91
crates/runtime/src/borrow.rs
Normal file
91
crates/runtime/src/borrow.rs
Normal file
@@ -0,0 +1,91 @@
|
||||
use crate::region::Region;
|
||||
use crate::GuestError;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct GuestBorrows {
|
||||
borrows: Vec<Region>,
|
||||
}
|
||||
|
||||
impl GuestBorrows {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
borrows: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_borrowed(&self, r: Region) -> bool {
|
||||
!self.borrows.iter().all(|b| !b.overlaps(r))
|
||||
}
|
||||
|
||||
pub fn borrow(&mut self, r: Region) -> Result<(), GuestError> {
|
||||
if self.is_borrowed(r) {
|
||||
Err(GuestError::PtrBorrowed(r))
|
||||
} else {
|
||||
self.borrows.push(r);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
#[test]
|
||||
fn nonoverlapping() {
|
||||
let mut bs = GuestBorrows::new();
|
||||
let r1 = Region::new(0, 10);
|
||||
let r2 = Region::new(10, 10);
|
||||
assert!(!r1.overlaps(r2));
|
||||
bs.borrow(r1).expect("can borrow r1");
|
||||
bs.borrow(r2).expect("can borrow r2");
|
||||
|
||||
let mut bs = GuestBorrows::new();
|
||||
let r1 = Region::new(10, 10);
|
||||
let r2 = Region::new(0, 10);
|
||||
assert!(!r1.overlaps(r2));
|
||||
bs.borrow(r1).expect("can borrow r1");
|
||||
bs.borrow(r2).expect("can borrow r2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn overlapping() {
|
||||
let mut bs = GuestBorrows::new();
|
||||
let r1 = Region::new(0, 10);
|
||||
let r2 = Region::new(9, 10);
|
||||
assert!(r1.overlaps(r2));
|
||||
bs.borrow(r1).expect("can borrow r1");
|
||||
assert!(bs.borrow(r2).is_err(), "cant borrow r2");
|
||||
|
||||
let mut bs = GuestBorrows::new();
|
||||
let r1 = Region::new(0, 10);
|
||||
let r2 = Region::new(2, 5);
|
||||
assert!(r1.overlaps(r2));
|
||||
bs.borrow(r1).expect("can borrow r1");
|
||||
assert!(bs.borrow(r2).is_err(), "cant borrow r2");
|
||||
|
||||
let mut bs = GuestBorrows::new();
|
||||
let r1 = Region::new(9, 10);
|
||||
let r2 = Region::new(0, 10);
|
||||
assert!(r1.overlaps(r2));
|
||||
bs.borrow(r1).expect("can borrow r1");
|
||||
assert!(bs.borrow(r2).is_err(), "cant borrow r2");
|
||||
|
||||
let mut bs = GuestBorrows::new();
|
||||
let r1 = Region::new(2, 5);
|
||||
let r2 = Region::new(0, 10);
|
||||
assert!(r1.overlaps(r2));
|
||||
bs.borrow(r1).expect("can borrow r1");
|
||||
assert!(bs.borrow(r2).is_err(), "cant borrow r2");
|
||||
|
||||
let mut bs = GuestBorrows::new();
|
||||
let r1 = Region::new(2, 5);
|
||||
let r2 = Region::new(10, 5);
|
||||
let r3 = Region::new(15, 5);
|
||||
let r4 = Region::new(0, 10);
|
||||
assert!(r1.overlaps(r4));
|
||||
bs.borrow(r1).expect("can borrow r1");
|
||||
bs.borrow(r2).expect("can borrow r2");
|
||||
bs.borrow(r3).expect("can borrow r3");
|
||||
assert!(bs.borrow(r4).is_err(), "cant borrow r4");
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,8 @@ pub enum GuestError {
|
||||
InvalidFlagValue(&'static str),
|
||||
#[error("Invalid enum value {0}")]
|
||||
InvalidEnumValue(&'static str),
|
||||
#[error("Pointer overflow")]
|
||||
PtrOverflow,
|
||||
#[error("Pointer out of bounds: {0:?}")]
|
||||
PtrOutOfBounds(Region),
|
||||
#[error("Pointer not aligned to {1}: {0:?}")]
|
||||
|
||||
@@ -40,6 +40,22 @@ pub trait GuestType<'a>: Sized {
|
||||
fn write(ptr: &GuestPtr<'_, Self>, val: Self) -> Result<(), GuestError>;
|
||||
}
|
||||
|
||||
/// A trait for `GuestType`s that have the same representation in guest memory
|
||||
/// as in Rust. These types can be used with the `GuestPtr::as_raw` method to
|
||||
/// view as a slice.
|
||||
///
|
||||
/// Unsafe trait because a correct GuestTypeTransparent implemengation ensures that the
|
||||
/// GuestPtr::as_raw methods are safe. This trait should only ever be implemented
|
||||
/// by wiggle_generate-produced code.
|
||||
pub unsafe trait GuestTypeTransparent<'a>: GuestType<'a> {
|
||||
/// Checks that the memory at `ptr` is a valid representation of `Self`.
|
||||
///
|
||||
/// Assumes that memory safety checks have already been performed: `ptr`
|
||||
/// has been checked to be aligned correctly and reside in memory using
|
||||
/// `GuestMemory::validate_size_align`
|
||||
fn validate(ptr: *mut Self) -> Result<(), GuestError>;
|
||||
}
|
||||
|
||||
macro_rules! primitives {
|
||||
($($i:ident)*) => ($(
|
||||
impl<'a> GuestType<'a> for $i {
|
||||
@@ -78,6 +94,15 @@ macro_rules! primitives {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<'a> GuestTypeTransparent<'a> for $i {
|
||||
#[inline]
|
||||
fn validate(_ptr: *mut $i) -> Result<(), GuestError> {
|
||||
// All bit patterns are safe, nothing to do here
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
)*)
|
||||
}
|
||||
|
||||
|
||||
@@ -6,11 +6,14 @@ use std::slice;
|
||||
use std::str;
|
||||
use std::sync::Arc;
|
||||
|
||||
mod borrow;
|
||||
mod error;
|
||||
mod guest_type;
|
||||
mod region;
|
||||
|
||||
pub use borrow::GuestBorrows;
|
||||
pub use error::GuestError;
|
||||
pub use guest_type::{GuestErrorType, GuestType};
|
||||
pub use guest_type::{GuestErrorType, GuestType, GuestTypeTransparent};
|
||||
pub use region::Region;
|
||||
|
||||
/// A trait which abstracts how to get at the region of host memory taht
|
||||
@@ -119,12 +122,12 @@ pub unsafe trait GuestMemory {
|
||||
// Figure out our pointer to the start of memory
|
||||
let start = match (base_ptr as usize).checked_add(offset as usize) {
|
||||
Some(ptr) => ptr,
|
||||
None => return Err(GuestError::PtrOutOfBounds(region)),
|
||||
None => return Err(GuestError::PtrOverflow),
|
||||
};
|
||||
// and use that to figure out the end pointer
|
||||
let end = match start.checked_add(len as usize) {
|
||||
Some(ptr) => ptr,
|
||||
None => return Err(GuestError::PtrOutOfBounds(region)),
|
||||
None => return Err(GuestError::PtrOverflow),
|
||||
};
|
||||
// and then verify that our end doesn't reach past the end of our memory
|
||||
if end > (base_ptr as usize) + (base_len as usize) {
|
||||
@@ -335,7 +338,7 @@ impl<'a, T: ?Sized + Pointee> GuestPtr<'a, T> {
|
||||
.and_then(|o| self.pointer.checked_add(o));
|
||||
let offset = match offset {
|
||||
Some(o) => o,
|
||||
None => return Err(GuestError::InvalidFlagValue("")),
|
||||
None => return Err(GuestError::PtrOverflow),
|
||||
};
|
||||
Ok(GuestPtr::new(self.mem, offset))
|
||||
}
|
||||
@@ -369,6 +372,54 @@ impl<'a, T> GuestPtr<'a, [T]> {
|
||||
(0..self.len()).map(move |i| base.add(i))
|
||||
}
|
||||
|
||||
/// Attempts to read a raw `*mut [T]` pointer from this pointer, performing
|
||||
/// bounds checks and type validation.
|
||||
/// The resulting `*mut [T]` can be used as a `&mut [t]` as long as the
|
||||
/// reference is dropped before any Wasm code is re-entered.
|
||||
///
|
||||
/// This function will return a raw pointer into host memory if all checks
|
||||
/// succeed (valid utf-8, valid pointers, etc). If any checks fail then
|
||||
/// `GuestError` will be returned.
|
||||
///
|
||||
/// Note that the `*mut [T]` pointer is still unsafe to use in general, but
|
||||
/// there are specific situations that it is safe to use. For more
|
||||
/// information about using the raw pointer, consult the [`GuestMemory`]
|
||||
/// trait documentation.
|
||||
///
|
||||
/// For safety against overlapping mutable borrows, the user must use the
|
||||
/// same `GuestBorrows` to create all *mut str or *mut [T] that are alive
|
||||
/// at the same time.
|
||||
pub fn as_raw(&self, bc: &mut GuestBorrows) -> Result<*mut [T], GuestError>
|
||||
where
|
||||
T: GuestTypeTransparent<'a>,
|
||||
{
|
||||
let len = match self.pointer.1.checked_mul(T::guest_size()) {
|
||||
Some(l) => l,
|
||||
None => return Err(GuestError::PtrOverflow),
|
||||
};
|
||||
let ptr =
|
||||
self.mem
|
||||
.validate_size_align(self.pointer.0, T::guest_align(), len)? as *mut T;
|
||||
|
||||
bc.borrow(Region {
|
||||
start: self.pointer.0,
|
||||
len,
|
||||
})?;
|
||||
|
||||
// Validate all elements in slice.
|
||||
// SAFETY: ptr has been validated by self.mem.validate_size_align
|
||||
for offs in 0..self.pointer.1 {
|
||||
T::validate(unsafe { ptr.add(offs as usize) })?;
|
||||
}
|
||||
|
||||
// SAFETY: iff there are no overlapping borrows (all uses of as_raw use this same
|
||||
// GuestBorrows), its valid to construct a *mut [T]
|
||||
unsafe {
|
||||
let s = slice::from_raw_parts_mut(ptr, self.pointer.1 as usize);
|
||||
Ok(s as *mut [T])
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a `GuestPtr` pointing to the base of the array for the interior
|
||||
/// type `T`.
|
||||
pub fn as_ptr(&self) -> GuestPtr<'a, T> {
|
||||
@@ -396,6 +447,8 @@ impl<'a> GuestPtr<'a, str> {
|
||||
|
||||
/// Attempts to read a raw `*mut str` pointer from this pointer, performing
|
||||
/// bounds checks and utf-8 checks.
|
||||
/// The resulting `*mut str` can be used as a `&mut str` as long as the
|
||||
/// reference is dropped before any Wasm code is re-entered.
|
||||
///
|
||||
/// This function will return a raw pointer into host memory if all checks
|
||||
/// succeed (valid utf-8, valid pointers, etc). If any checks fail then
|
||||
@@ -405,12 +458,22 @@ impl<'a> GuestPtr<'a, str> {
|
||||
/// there are specific situations that it is safe to use. For more
|
||||
/// information about using the raw pointer, consult the [`GuestMemory`]
|
||||
/// trait documentation.
|
||||
pub fn as_raw(&self) -> Result<*mut str, GuestError> {
|
||||
///
|
||||
/// For safety against overlapping mutable borrows, the user must use the
|
||||
/// same `GuestBorrows` to create all *mut str or *mut [T] that are alive
|
||||
/// at the same time.
|
||||
pub fn as_raw(&self, bc: &mut GuestBorrows) -> Result<*mut str, GuestError> {
|
||||
let ptr = self
|
||||
.mem
|
||||
.validate_size_align(self.pointer.0, 1, self.pointer.1)?;
|
||||
|
||||
// TODO: doc unsafety here
|
||||
bc.borrow(Region {
|
||||
start: self.pointer.0,
|
||||
len: self.pointer.1,
|
||||
})?;
|
||||
|
||||
// SAFETY: iff there are no overlapping borrows (all uses of as_raw use this same
|
||||
// GuestBorrows), its valid to construct a *mut str
|
||||
unsafe {
|
||||
let s = slice::from_raw_parts_mut(ptr, self.pointer.1 as usize);
|
||||
match str::from_utf8_mut(s) {
|
||||
|
||||
@@ -2,6 +2,45 @@ use proptest::prelude::*;
|
||||
use std::cell::UnsafeCell;
|
||||
use wiggle_runtime::GuestMemory;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MemAreas(Vec<MemArea>);
|
||||
impl MemAreas {
|
||||
pub fn new() -> Self {
|
||||
MemAreas(Vec::new())
|
||||
}
|
||||
pub fn insert(&mut self, a: MemArea) {
|
||||
// Find if `a` is already in the vector
|
||||
match self.0.binary_search(&a) {
|
||||
// It is present - insert it next to existing one
|
||||
Ok(loc) => self.0.insert(loc, a),
|
||||
// It is not present - heres where to insert it
|
||||
Err(loc) => self.0.insert(loc, a),
|
||||
}
|
||||
}
|
||||
pub fn iter(&self) -> impl Iterator<Item = &MemArea> {
|
||||
self.0.iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> From<R> for MemAreas
|
||||
where
|
||||
R: AsRef<[MemArea]>,
|
||||
{
|
||||
fn from(ms: R) -> MemAreas {
|
||||
let mut out = MemAreas::new();
|
||||
for m in ms.as_ref().into_iter() {
|
||||
out.insert(*m);
|
||||
}
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<Vec<MemArea>> for MemAreas {
|
||||
fn into(self) -> Vec<MemArea> {
|
||||
self.0.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(align(4096))]
|
||||
pub struct HostMemory {
|
||||
buffer: UnsafeCell<[u8; 4096]>,
|
||||
@@ -26,6 +65,42 @@ impl HostMemory {
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// Takes a sorted list or memareas, and gives a sorted list of memareas covering
|
||||
/// the parts of memory not covered by the previous
|
||||
pub fn invert(regions: &MemAreas) -> MemAreas {
|
||||
let mut out = MemAreas::new();
|
||||
let mut start = 0;
|
||||
for r in regions.iter() {
|
||||
let len = r.ptr - start;
|
||||
if len > 0 {
|
||||
out.insert(MemArea {
|
||||
ptr: start,
|
||||
len: r.ptr - start,
|
||||
});
|
||||
}
|
||||
start = r.ptr + r.len;
|
||||
}
|
||||
if start < 4096 {
|
||||
out.insert(MemArea {
|
||||
ptr: start,
|
||||
len: 4096 - start,
|
||||
});
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
pub fn byte_slice_strat(size: u32, exclude: &MemAreas) -> BoxedStrategy<MemArea> {
|
||||
let available: Vec<MemArea> = Self::invert(exclude)
|
||||
.iter()
|
||||
.flat_map(|a| a.inside(size))
|
||||
.collect();
|
||||
|
||||
Just(available)
|
||||
.prop_filter("available memory for allocation", |a| !a.is_empty())
|
||||
.prop_flat_map(|a| prop::sample::select(a))
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl GuestMemory for HostMemory {
|
||||
@@ -37,7 +112,7 @@ unsafe impl GuestMemory for HostMemory {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub struct MemArea {
|
||||
pub ptr: u32,
|
||||
pub len: u32,
|
||||
@@ -48,7 +123,7 @@ impl MemArea {
|
||||
// test.
|
||||
// So, I implemented this one with std::ops::Range so it is less likely I wrote the same bug in two
|
||||
// places.
|
||||
pub fn overlapping(&self, b: &Self) -> bool {
|
||||
pub fn overlapping(&self, b: Self) -> bool {
|
||||
// a_range is all elems in A
|
||||
let a_range = std::ops::Range {
|
||||
start: self.ptr,
|
||||
@@ -73,18 +148,33 @@ impl MemArea {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
pub fn non_overlapping_set(areas: &[&Self]) -> bool {
|
||||
// A is all areas
|
||||
for (i, a) in areas.iter().enumerate() {
|
||||
// (A, B) is every pair of areas
|
||||
for b in areas[i + 1..].iter() {
|
||||
if a.overlapping(b) {
|
||||
return false;
|
||||
pub fn non_overlapping_set<M>(areas: M) -> bool
|
||||
where
|
||||
M: Into<MemAreas>,
|
||||
{
|
||||
let areas = areas.into();
|
||||
for (aix, a) in areas.iter().enumerate() {
|
||||
for (bix, b) in areas.iter().enumerate() {
|
||||
if aix != bix {
|
||||
// (A, B) is every pairing of areas
|
||||
if a.overlapping(*b) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Enumerate all memareas of size `len` inside a given area
|
||||
fn inside(&self, len: u32) -> impl Iterator<Item = MemArea> {
|
||||
let end: i64 = self.len as i64 - len as i64;
|
||||
let start = self.ptr;
|
||||
(0..end).into_iter().map(move |v| MemArea {
|
||||
ptr: start + v as u32,
|
||||
len,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -97,6 +187,104 @@ mod test {
|
||||
let h = Box::new(h);
|
||||
assert_eq!(h.base().0 as usize % 4096, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invert() {
|
||||
fn invert_equality(input: &[MemArea], expected: &[MemArea]) {
|
||||
let input: MemAreas = input.into();
|
||||
let inverted: Vec<MemArea> = HostMemory::invert(&input).into();
|
||||
assert_eq!(expected, inverted.as_slice());
|
||||
}
|
||||
|
||||
invert_equality(&[], &[MemArea { ptr: 0, len: 4096 }]);
|
||||
invert_equality(
|
||||
&[MemArea { ptr: 0, len: 1 }],
|
||||
&[MemArea { ptr: 1, len: 4095 }],
|
||||
);
|
||||
|
||||
invert_equality(
|
||||
&[MemArea { ptr: 1, len: 1 }],
|
||||
&[MemArea { ptr: 0, len: 1 }, MemArea { ptr: 2, len: 4094 }],
|
||||
);
|
||||
|
||||
invert_equality(
|
||||
&[MemArea { ptr: 1, len: 4095 }],
|
||||
&[MemArea { ptr: 0, len: 1 }],
|
||||
);
|
||||
|
||||
invert_equality(
|
||||
&[MemArea { ptr: 0, len: 1 }, MemArea { ptr: 1, len: 4095 }],
|
||||
&[],
|
||||
);
|
||||
|
||||
invert_equality(
|
||||
&[MemArea { ptr: 1, len: 2 }, MemArea { ptr: 4, len: 1 }],
|
||||
&[
|
||||
MemArea { ptr: 0, len: 1 },
|
||||
MemArea { ptr: 3, len: 1 },
|
||||
MemArea { ptr: 5, len: 4091 },
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
fn set_of_slices_strat(
|
||||
s1: u32,
|
||||
s2: u32,
|
||||
s3: u32,
|
||||
) -> BoxedStrategy<(MemArea, MemArea, MemArea)> {
|
||||
HostMemory::byte_slice_strat(s1, &MemAreas::new())
|
||||
.prop_flat_map(move |a1| {
|
||||
(
|
||||
Just(a1),
|
||||
HostMemory::byte_slice_strat(s2, &MemAreas::from(&[a1])),
|
||||
)
|
||||
})
|
||||
.prop_flat_map(move |(a1, a2)| {
|
||||
(
|
||||
Just(a1),
|
||||
Just(a2),
|
||||
HostMemory::byte_slice_strat(s3, &MemAreas::from(&[a1, a2])),
|
||||
)
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trivial_inside() {
|
||||
let a = MemArea { ptr: 24, len: 4072 };
|
||||
let interior = a.inside(24).collect::<Vec<_>>();
|
||||
|
||||
assert!(interior.len() > 0);
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
// For some random region of decent size
|
||||
fn inside(r in HostMemory::mem_area_strat(123)) {
|
||||
let set_of_r = MemAreas::from(&[r]);
|
||||
// All regions outside of r:
|
||||
let exterior = HostMemory::invert(&set_of_r);
|
||||
// All regions inside of r:
|
||||
let interior = r.inside(22);
|
||||
for i in interior {
|
||||
// i overlaps with r:
|
||||
assert!(r.overlapping(i));
|
||||
// i is inside r:
|
||||
assert!(i.ptr >= r.ptr);
|
||||
assert!(r.ptr + r.len >= i.ptr + i.len);
|
||||
// the set of exterior and i is non-overlapping
|
||||
let mut all = exterior.clone();
|
||||
all.insert(i);
|
||||
assert!(MemArea::non_overlapping_set(all));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn byte_slices((s1, s2, s3) in set_of_slices_strat(12, 34, 56)) {
|
||||
let all = MemAreas::from(&[s1, s2, s3]);
|
||||
assert!(MemArea::non_overlapping_set(all));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
use std::cell::RefCell;
|
||||
|
||||
Reference in New Issue
Block a user