wiggle: automate borrow checking, explicitly passing borrow checker throughout
This commit is contained in:
@@ -1,64 +1,67 @@
|
||||
use crate::error::GuestError;
|
||||
use crate::region::Region;
|
||||
use crate::{GuestError, GuestPtr, GuestType};
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct GuestBorrows {
|
||||
borrows: Vec<Region>,
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct BorrowHandle(usize);
|
||||
|
||||
pub struct BorrowChecker {
|
||||
bc: RefCell<InnerBorrowChecker>,
|
||||
}
|
||||
|
||||
impl GuestBorrows {
|
||||
impl BorrowChecker {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
borrows: Vec::new(),
|
||||
BorrowChecker {
|
||||
bc: RefCell::new(InnerBorrowChecker::new()),
|
||||
}
|
||||
}
|
||||
pub fn borrow(&self, r: Region) -> Result<BorrowHandle, GuestError> {
|
||||
self.bc
|
||||
.borrow_mut()
|
||||
.borrow(r)
|
||||
.ok_or_else(|| GuestError::PtrBorrowed(r))
|
||||
}
|
||||
pub fn unborrow(&self, h: BorrowHandle) {
|
||||
self.bc.borrow_mut().unborrow(h)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct InnerBorrowChecker {
|
||||
borrows: HashMap<BorrowHandle, Region>,
|
||||
next_handle: BorrowHandle,
|
||||
}
|
||||
|
||||
impl InnerBorrowChecker {
|
||||
fn new() -> Self {
|
||||
InnerBorrowChecker {
|
||||
borrows: HashMap::new(),
|
||||
next_handle: BorrowHandle(0),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_borrowed(&self, r: Region) -> bool {
|
||||
!self.borrows.iter().all(|b| !b.overlaps(r))
|
||||
!self.borrows.values().all(|b| !b.overlaps(r))
|
||||
}
|
||||
|
||||
pub(crate) fn borrow(&mut self, r: Region) -> Result<(), GuestError> {
|
||||
fn new_handle(&mut self) -> BorrowHandle {
|
||||
let h = self.next_handle;
|
||||
self.next_handle = BorrowHandle(h.0 + 1);
|
||||
h
|
||||
}
|
||||
|
||||
fn borrow(&mut self, r: Region) -> Option<BorrowHandle> {
|
||||
if self.is_borrowed(r) {
|
||||
Err(GuestError::PtrBorrowed(r))
|
||||
} else {
|
||||
self.borrows.push(r);
|
||||
Ok(())
|
||||
return None;
|
||||
}
|
||||
let h = self.new_handle();
|
||||
self.borrows.insert(h, r);
|
||||
Some(h)
|
||||
}
|
||||
|
||||
/// Borrow the region of memory pointed to by a `GuestPtr`. This is required for safety if
|
||||
/// you are dereferencing `GuestPtr`s while holding a reference to a slice via
|
||||
/// `GuestPtr::as_raw`.
|
||||
pub fn borrow_pointee<'a, T>(&mut self, p: &GuestPtr<'a, T>) -> Result<(), GuestError>
|
||||
where
|
||||
T: GuestType<'a>,
|
||||
{
|
||||
self.borrow(Region {
|
||||
start: p.offset(),
|
||||
len: T::guest_size(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Borrow the slice of memory pointed to by a `GuestPtr<[T]>`. This is required for safety if
|
||||
/// you are dereferencing the `GuestPtr`s while holding a reference to another slice via
|
||||
/// `GuestPtr::as_raw`. Not required if using `GuestPtr::as_raw` on this pointer.
|
||||
pub fn borrow_slice<'a, T>(&mut self, p: &GuestPtr<'a, [T]>) -> Result<(), GuestError>
|
||||
where
|
||||
T: GuestType<'a>,
|
||||
{
|
||||
let (start, elems) = p.offset();
|
||||
let len = T::guest_size()
|
||||
.checked_mul(elems)
|
||||
.ok_or_else(|| GuestError::PtrOverflow)?;
|
||||
self.borrow(Region { start, len })
|
||||
}
|
||||
|
||||
/// Borrow the slice of memory pointed to by a `GuestPtr<str>`. This is required for safety if
|
||||
/// you are dereferencing the `GuestPtr`s while holding a reference to another slice via
|
||||
/// `GuestPtr::as_raw`. Not required if using `GuestPtr::as_raw` on this pointer.
|
||||
pub fn borrow_str(&mut self, p: &GuestPtr<str>) -> Result<(), GuestError> {
|
||||
let (start, len) = p.offset();
|
||||
self.borrow(Region { start, len })
|
||||
fn unborrow(&mut self, h: BorrowHandle) {
|
||||
let _ = self.borrows.remove(&h);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,14 +70,14 @@ mod test {
|
||||
use super::*;
|
||||
#[test]
|
||||
fn nonoverlapping() {
|
||||
let mut bs = GuestBorrows::new();
|
||||
let mut bs = InnerBorrowChecker::new();
|
||||
let r1 = Region::new(0, 10);
|
||||
let r2 = Region::new(10, 10);
|
||||
assert!(!r1.overlaps(r2));
|
||||
bs.borrow(r1).expect("can borrow r1");
|
||||
bs.borrow(r2).expect("can borrow r2");
|
||||
|
||||
let mut bs = GuestBorrows::new();
|
||||
let mut bs = InnerBorrowChecker::new();
|
||||
let r1 = Region::new(10, 10);
|
||||
let r2 = Region::new(0, 10);
|
||||
assert!(!r1.overlaps(r2));
|
||||
@@ -84,35 +87,35 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn overlapping() {
|
||||
let mut bs = GuestBorrows::new();
|
||||
let mut bs = InnerBorrowChecker::new();
|
||||
let r1 = Region::new(0, 10);
|
||||
let r2 = Region::new(9, 10);
|
||||
assert!(r1.overlaps(r2));
|
||||
bs.borrow(r1).expect("can borrow r1");
|
||||
assert!(bs.borrow(r2).is_err(), "cant borrow r2");
|
||||
assert!(bs.borrow(r2).is_none(), "cant borrow r2");
|
||||
|
||||
let mut bs = GuestBorrows::new();
|
||||
let mut bs = InnerBorrowChecker::new();
|
||||
let r1 = Region::new(0, 10);
|
||||
let r2 = Region::new(2, 5);
|
||||
assert!(r1.overlaps(r2));
|
||||
bs.borrow(r1).expect("can borrow r1");
|
||||
assert!(bs.borrow(r2).is_err(), "cant borrow r2");
|
||||
assert!(bs.borrow(r2).is_none(), "cant borrow r2");
|
||||
|
||||
let mut bs = GuestBorrows::new();
|
||||
let mut bs = InnerBorrowChecker::new();
|
||||
let r1 = Region::new(9, 10);
|
||||
let r2 = Region::new(0, 10);
|
||||
assert!(r1.overlaps(r2));
|
||||
bs.borrow(r1).expect("can borrow r1");
|
||||
assert!(bs.borrow(r2).is_err(), "cant borrow r2");
|
||||
assert!(bs.borrow(r2).is_none(), "cant borrow r2");
|
||||
|
||||
let mut bs = GuestBorrows::new();
|
||||
let mut bs = InnerBorrowChecker::new();
|
||||
let r1 = Region::new(2, 5);
|
||||
let r2 = Region::new(0, 10);
|
||||
assert!(r1.overlaps(r2));
|
||||
bs.borrow(r1).expect("can borrow r1");
|
||||
assert!(bs.borrow(r2).is_err(), "cant borrow r2");
|
||||
assert!(bs.borrow(r2).is_none(), "cant borrow r2");
|
||||
|
||||
let mut bs = GuestBorrows::new();
|
||||
let mut bs = InnerBorrowChecker::new();
|
||||
let r1 = Region::new(2, 5);
|
||||
let r2 = Region::new(10, 5);
|
||||
let r3 = Region::new(15, 5);
|
||||
@@ -121,6 +124,23 @@ mod test {
|
||||
bs.borrow(r1).expect("can borrow r1");
|
||||
bs.borrow(r2).expect("can borrow r2");
|
||||
bs.borrow(r3).expect("can borrow r3");
|
||||
assert!(bs.borrow(r4).is_err(), "cant borrow r4");
|
||||
assert!(bs.borrow(r4).is_none(), "cant borrow r4");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unborrowing() {
|
||||
let mut bs = InnerBorrowChecker::new();
|
||||
let r1 = Region::new(0, 10);
|
||||
let r2 = Region::new(10, 10);
|
||||
assert!(!r1.overlaps(r2));
|
||||
let _h1 = bs.borrow(r1).expect("can borrow r1");
|
||||
let h2 = bs.borrow(r2).expect("can borrow r2");
|
||||
|
||||
assert!(bs.borrow(r2).is_none(), "can't borrow r2 twice");
|
||||
bs.unborrow(h2);
|
||||
|
||||
let _h3 = bs
|
||||
.borrow(r2)
|
||||
.expect("can borrow r2 again now that its been unborrowed");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,7 +131,7 @@ impl<'a, T> GuestType<'a> for GuestPtr<'a, T> {
|
||||
|
||||
fn read(ptr: &GuestPtr<'a, Self>) -> Result<Self, GuestError> {
|
||||
let offset = ptr.cast::<u32>().read()?;
|
||||
Ok(GuestPtr::new(ptr.mem(), offset))
|
||||
Ok(GuestPtr::new(ptr.mem(), ptr.borrow_checker(), offset))
|
||||
}
|
||||
|
||||
fn write(ptr: &GuestPtr<'_, Self>, val: Self) -> Result<(), GuestError> {
|
||||
|
||||
@@ -16,7 +16,8 @@ mod error;
|
||||
mod guest_type;
|
||||
mod region;
|
||||
|
||||
pub use borrow::GuestBorrows;
|
||||
pub use borrow::BorrowChecker;
|
||||
use borrow::BorrowHandle;
|
||||
pub use error::GuestError;
|
||||
pub use guest_type::{GuestErrorType, GuestType, GuestTypeTransparent};
|
||||
pub use region::Region;
|
||||
@@ -150,12 +151,12 @@ pub unsafe trait GuestMemory {
|
||||
/// Note that `T` can be almost any type, and typically `offset` is a `u32`.
|
||||
/// The exception is slices and strings, in which case `offset` is a `(u32,
|
||||
/// u32)` of `(offset, length)`.
|
||||
fn ptr<'a, T>(&'a self, offset: T::Pointer) -> GuestPtr<'a, T>
|
||||
fn ptr<'a, T>(&'a self, bc: &'a BorrowChecker, offset: T::Pointer) -> GuestPtr<'a, T>
|
||||
where
|
||||
Self: Sized,
|
||||
T: ?Sized + Pointee,
|
||||
{
|
||||
GuestPtr::new(self, offset)
|
||||
GuestPtr::new(self, bc, offset)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -237,6 +238,7 @@ unsafe impl<T: ?Sized + GuestMemory> GuestMemory for Arc<T> {
|
||||
/// already-attached helper methods.
|
||||
pub struct GuestPtr<'a, T: ?Sized + Pointee> {
|
||||
mem: &'a (dyn GuestMemory + 'a),
|
||||
bc: &'a BorrowChecker,
|
||||
pointer: T::Pointer,
|
||||
_marker: marker::PhantomData<&'a Cell<T>>,
|
||||
}
|
||||
@@ -247,9 +249,14 @@ impl<'a, T: ?Sized + Pointee> GuestPtr<'a, T> {
|
||||
/// Note that for sized types like `u32`, `GuestPtr<T>`, etc, the `pointer`
|
||||
/// vlue is a `u32` offset into guest memory. For slices and strings,
|
||||
/// `pointer` is a `(u32, u32)` offset/length pair.
|
||||
pub fn new(mem: &'a (dyn GuestMemory + 'a), pointer: T::Pointer) -> GuestPtr<'_, T> {
|
||||
pub fn new(
|
||||
mem: &'a (dyn GuestMemory + 'a),
|
||||
bc: &'a BorrowChecker,
|
||||
pointer: T::Pointer,
|
||||
) -> GuestPtr<'a, T> {
|
||||
GuestPtr {
|
||||
mem,
|
||||
bc,
|
||||
pointer,
|
||||
_marker: marker::PhantomData,
|
||||
}
|
||||
@@ -268,6 +275,11 @@ impl<'a, T: ?Sized + Pointee> GuestPtr<'a, T> {
|
||||
self.mem
|
||||
}
|
||||
|
||||
/// Returns the borrow checker that this pointer uses
|
||||
pub fn borrow_checker(&self) -> &'a BorrowChecker {
|
||||
self.bc
|
||||
}
|
||||
|
||||
/// Casts this `GuestPtr` type to a different type.
|
||||
///
|
||||
/// This is a safe method which is useful for simply reinterpreting the type
|
||||
@@ -278,7 +290,7 @@ impl<'a, T: ?Sized + Pointee> GuestPtr<'a, T> {
|
||||
where
|
||||
T: Pointee<Pointer = u32>,
|
||||
{
|
||||
GuestPtr::new(self.mem, self.pointer)
|
||||
GuestPtr::new(self.mem, self.bc, self.pointer)
|
||||
}
|
||||
|
||||
/// Safely read a value from this pointer.
|
||||
@@ -345,7 +357,7 @@ impl<'a, T: ?Sized + Pointee> GuestPtr<'a, T> {
|
||||
Some(o) => o,
|
||||
None => return Err(GuestError::PtrOverflow),
|
||||
};
|
||||
Ok(GuestPtr::new(self.mem, offset))
|
||||
Ok(GuestPtr::new(self.mem, self.bc, offset))
|
||||
}
|
||||
|
||||
/// Returns a `GuestPtr` for an array of `T`s using this pointer as the
|
||||
@@ -354,7 +366,7 @@ impl<'a, T: ?Sized + Pointee> GuestPtr<'a, T> {
|
||||
where
|
||||
T: GuestType<'a> + Pointee<Pointer = u32>,
|
||||
{
|
||||
GuestPtr::new(self.mem, (self.pointer, elems))
|
||||
GuestPtr::new(self.mem, self.bc, (self.pointer, elems))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -403,7 +415,7 @@ impl<'a, T> GuestPtr<'a, [T]> {
|
||||
/// For safety against overlapping mutable borrows, the user must use the
|
||||
/// same `GuestBorrows` to create all `*mut str` or `*mut [T]` that are alive
|
||||
/// at the same time.
|
||||
pub fn as_raw(&self, bc: &mut GuestBorrows) -> Result<*mut [T], GuestError>
|
||||
pub fn as_slice(&self) -> Result<GuestSlice<'a, T>, GuestError>
|
||||
where
|
||||
T: GuestTypeTransparent<'a>,
|
||||
{
|
||||
@@ -415,7 +427,7 @@ impl<'a, T> GuestPtr<'a, [T]> {
|
||||
self.mem
|
||||
.validate_size_align(self.pointer.0, T::guest_align(), len)? as *mut T;
|
||||
|
||||
bc.borrow(Region {
|
||||
let borrow = self.bc.borrow(Region {
|
||||
start: self.pointer.0,
|
||||
len,
|
||||
})?;
|
||||
@@ -428,10 +440,16 @@ impl<'a, T> GuestPtr<'a, [T]> {
|
||||
|
||||
// SAFETY: iff there are no overlapping borrows (all uses of as_raw use this same
|
||||
// GuestBorrows), its valid to construct a *mut [T]
|
||||
unsafe {
|
||||
let ptr = unsafe {
|
||||
let s = slice::from_raw_parts_mut(ptr, self.pointer.1 as usize);
|
||||
Ok(s as *mut [T])
|
||||
}
|
||||
s as *mut [T]
|
||||
};
|
||||
|
||||
Ok(GuestSlice {
|
||||
ptr,
|
||||
bc: self.bc,
|
||||
borrow,
|
||||
})
|
||||
}
|
||||
|
||||
/// Copies the data pointed to by `slice` into this guest region.
|
||||
@@ -451,22 +469,20 @@ impl<'a, T> GuestPtr<'a, [T]> {
|
||||
T: GuestTypeTransparent<'a> + Copy,
|
||||
{
|
||||
// bounds check ...
|
||||
let raw = self.as_raw(&mut GuestBorrows::new())?;
|
||||
unsafe {
|
||||
// ... length check ...
|
||||
if (*raw).len() != slice.len() {
|
||||
return Err(GuestError::SliceLengthsDiffer);
|
||||
}
|
||||
// ... and copy!
|
||||
(*raw).copy_from_slice(slice);
|
||||
Ok(())
|
||||
let mut self_slice = self.as_slice()?;
|
||||
// ... length check ...
|
||||
if self_slice.len() != slice.len() {
|
||||
return Err(GuestError::SliceLengthsDiffer);
|
||||
}
|
||||
// ... and copy!
|
||||
self_slice.copy_from_slice(slice);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns a `GuestPtr` pointing to the base of the array for the interior
|
||||
/// type `T`.
|
||||
pub fn as_ptr(&self) -> GuestPtr<'a, T> {
|
||||
GuestPtr::new(self.mem, self.offset_base())
|
||||
GuestPtr::new(self.mem, self.bc, self.offset_base())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -485,7 +501,7 @@ impl<'a> GuestPtr<'a, str> {
|
||||
/// Returns a raw pointer for the underlying slice of bytes that this
|
||||
/// pointer points to.
|
||||
pub fn as_bytes(&self) -> GuestPtr<'a, [u8]> {
|
||||
GuestPtr::new(self.mem, self.pointer)
|
||||
GuestPtr::new(self.mem, self.bc, self.pointer)
|
||||
}
|
||||
|
||||
/// Attempts to read a raw `*mut str` pointer from this pointer, performing
|
||||
@@ -505,24 +521,26 @@ impl<'a> GuestPtr<'a, str> {
|
||||
/// For safety against overlapping mutable borrows, the user must use the
|
||||
/// same `GuestBorrows` to create all `*mut str` or `*mut [T]` that are
|
||||
/// alive at the same time.
|
||||
pub fn as_raw(&self, bc: &mut GuestBorrows) -> Result<*mut str, GuestError> {
|
||||
pub fn as_str(&self) -> Result<GuestStr<'a>, GuestError> {
|
||||
let ptr = self
|
||||
.mem
|
||||
.validate_size_align(self.pointer.0, 1, self.pointer.1)?;
|
||||
|
||||
bc.borrow(Region {
|
||||
let borrow = self.bc.borrow(Region {
|
||||
start: self.pointer.0,
|
||||
len: self.pointer.1,
|
||||
})?;
|
||||
|
||||
// SAFETY: iff there are no overlapping borrows (all uses of as_raw use this same
|
||||
// GuestBorrows), its valid to construct a *mut str
|
||||
unsafe {
|
||||
let s = slice::from_raw_parts_mut(ptr, self.pointer.1 as usize);
|
||||
match str::from_utf8_mut(s) {
|
||||
Ok(s) => Ok(s),
|
||||
Err(e) => Err(GuestError::InvalidUtf8(e)),
|
||||
}
|
||||
let ptr = unsafe { slice::from_raw_parts_mut(ptr, self.pointer.1 as usize) };
|
||||
match str::from_utf8_mut(ptr) {
|
||||
Ok(ptr) => Ok(GuestStr {
|
||||
ptr,
|
||||
bc: self.bc,
|
||||
borrow,
|
||||
}),
|
||||
Err(e) => Err(GuestError::InvalidUtf8(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -541,6 +559,56 @@ impl<T: ?Sized + Pointee> fmt::Debug for GuestPtr<'_, T> {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GuestSlice<'a, T> {
|
||||
ptr: *mut [T],
|
||||
bc: &'a BorrowChecker,
|
||||
borrow: BorrowHandle,
|
||||
}
|
||||
|
||||
impl<'a, T> std::ops::Deref for GuestSlice<'a, T> {
|
||||
type Target = [T];
|
||||
fn deref(&self) -> &Self::Target {
|
||||
unsafe { self.ptr.as_ref().expect("ptr guaranteed to be non-null") }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> std::ops::DerefMut for GuestSlice<'a, T> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
unsafe { self.ptr.as_mut().expect("ptr guaranteed to be non-null") }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Drop for GuestSlice<'a, T> {
|
||||
fn drop(&mut self) {
|
||||
self.bc.unborrow(self.borrow)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GuestStr<'a> {
|
||||
ptr: *mut str,
|
||||
bc: &'a BorrowChecker,
|
||||
borrow: BorrowHandle,
|
||||
}
|
||||
|
||||
impl<'a> std::ops::Deref for GuestStr<'a> {
|
||||
type Target = str;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
unsafe { self.ptr.as_ref().expect("ptr guaranteed to be non-null") }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> std::ops::DerefMut for GuestStr<'a> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
unsafe { self.ptr.as_mut().expect("ptr guaranteed to be non-null") }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Drop for GuestStr<'a> {
|
||||
fn drop(&mut self) {
|
||||
self.bc.unborrow(self.borrow)
|
||||
}
|
||||
}
|
||||
|
||||
mod private {
|
||||
pub trait Sealed {}
|
||||
impl<T> Sealed for T {}
|
||||
|
||||
Reference in New Issue
Block a user