diff --git a/crates/generate/src/funcs.rs b/crates/generate/src/funcs.rs index ccf7365eab..de2c7413da 100644 --- a/crates/generate/src/funcs.rs +++ b/crates/generate/src/funcs.rs @@ -111,21 +111,25 @@ fn marshal_arg( let interface_typename = names.type_ref(&tref); let name = names.func_param(¶m.name); - let value_error_handling = if let Some(tref) = error_type { - let abi_ret = match tref.type_().passed_by() { - witx::TypePassedBy::Value(atom) => names.atom_type(atom), - _ => unreachable!("err should always be passed by value"), - }; - let err_typename = names.type_ref(&tref); - quote! { - let err: #err_typename = ::memory::GuestError::from_value_error(e, ctx); - return #abi_ret::from(err); - } - } else { - quote! { - panic!("memory error: {:?}", e) + let error_handling = |method| -> TokenStream { + if let Some(tref) = error_type { + let abi_ret = match tref.type_().passed_by() { + witx::TypePassedBy::Value(atom) => names.atom_type(atom), + _ => unreachable!("err should always be passed by value"), + }; + let err_typename = names.type_ref(&tref); + quote! { + let err: #err_typename = ::memory::GuestError::#method(e, ctx); + return #abi_ret::from(err); + } + } else { + quote! { + panic!("error: {:?}", e) + } } }; + let value_error_handling = error_handling(quote!(from_value_error)); + let memory_error_handling = error_handling(quote!(from_memory_error)); let try_into_conversion = quote! { use ::std::convert::TryInto; @@ -162,6 +166,28 @@ fn marshal_arg( }, witx::BuiltinType::String => unimplemented!("string types unimplemented"), }, - _ => unimplemented!("only enums and builtins so far"), + witx::Type::Pointer(pointee) => { + let pointee_type = names.type_ref(pointee); + quote! { + let #name = match memory.ptr_mut::<#pointee_type>(#name as u32) { + Ok(p) => p, + Err(e) => { + #memory_error_handling + } + }; + } + } + witx::Type::ConstPointer(pointee) => { + let pointee_type = names.type_ref(pointee); + quote! { + let #name = match memory.ptr::<#pointee_type>(#name as u32) { + Ok(p) => p, + Err(e) => { + #memory_error_handling + } + }; + } + } + _ => unimplemented!("argument type marshalling"), } } diff --git a/crates/generate/src/names.rs b/crates/generate/src/names.rs index 78cec1e0b8..6465e7f832 100644 --- a/crates/generate/src/names.rs +++ b/crates/generate/src/names.rs @@ -50,6 +50,14 @@ impl Names { } TypeRef::Value(ty) => match &**ty { witx::Type::Builtin(builtin) => self.builtin_type(*builtin), + witx::Type::Pointer(pointee) => { + let pointee_type = self.type_ref(&pointee); + quote!(::memory::GuestPtrMut<#pointee_type>) + } + witx::Type::ConstPointer(pointee) => { + let pointee_type = self.type_ref(&pointee); + quote!(::memory::GuestPtr<#pointee_type>) + } _ => unimplemented!("anonymous type ref"), }, } diff --git a/crates/generate/src/types.rs b/crates/generate/src/types.rs index 59b99507a5..a03b4003a2 100644 --- a/crates/generate/src/types.rs +++ b/crates/generate/src/types.rs @@ -95,12 +95,12 @@ fn define_enum(names: &Names, name: &witx::Id, e: &witx::EnumDatatype) -> TokenS } impl ::memory::GuestTypeCopy for #ident { - fn read_val(src: ::memory::GuestPtr<#ident>) -> Result<#ident, ::memory::GuestValueError> { + fn read_val>(src: &P) -> Result<#ident, ::memory::GuestValueError> { use ::std::convert::TryInto; let val = unsafe { ::std::ptr::read_unaligned(src.ptr() as *const #repr) }; val.try_into() } - fn write_val(val: #ident, dest: ::memory::GuestPtrMut<#ident>) { + fn write_val(val: #ident, dest: &::memory::GuestPtrMut<#ident>) { let val: #repr = val.into(); unsafe { ::std::ptr::write_unaligned(dest.ptr_mut() as *mut #repr, val) diff --git a/crates/memory/src/borrow.rs b/crates/memory/src/borrow.rs index 50d5656f53..c691888db9 100644 --- a/crates/memory/src/borrow.rs +++ b/crates/memory/src/borrow.rs @@ -1,59 +1,66 @@ +use std::collections::HashMap; + use crate::region::Region; +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub struct BorrowHandle(usize); + pub struct GuestBorrows { - immutable: Vec, - mutable: Vec, + immutable: HashMap, + mutable: HashMap, + next_handle: BorrowHandle, } impl GuestBorrows { pub fn new() -> Self { GuestBorrows { - immutable: Vec::new(), - mutable: Vec::new(), + immutable: HashMap::new(), + mutable: HashMap::new(), + next_handle: BorrowHandle(0), } } fn is_borrowed_immut(&self, r: Region) -> bool { - !self.immutable.iter().all(|b| !b.overlaps(r)) + !self.immutable.values().all(|b| !b.overlaps(r)) } fn is_borrowed_mut(&self, r: Region) -> bool { - !self.mutable.iter().all(|b| !b.overlaps(r)) + !self.mutable.values().all(|b| !b.overlaps(r)) } - pub fn borrow_immut(&mut self, r: Region) -> bool { + fn new_handle(&mut self) -> BorrowHandle { + let h = self.next_handle; + self.next_handle = BorrowHandle(h.0 + 1); + h + } + + pub fn borrow_immut(&mut self, r: Region) -> Option { if self.is_borrowed_mut(r) { - return false; + return None; } - self.immutable.push(r); - true + let h = self.new_handle(); + self.immutable.insert(h, r); + Some(h) } - pub fn unborrow_immut(&mut self, r: Region) { - let (ix, _) = self - .immutable - .iter() - .enumerate() - .find(|(_, reg)| r == **reg) - .expect("region exists in borrows"); - self.immutable.remove(ix); + pub fn unborrow_immut(&mut self, h: BorrowHandle) { + self.immutable + .remove(&h) + .expect("handle exists in immutable borrows"); } - pub fn borrow_mut(&mut self, r: Region) -> bool { + pub fn borrow_mut(&mut self, r: Region) -> Option { if self.is_borrowed_immut(r) || self.is_borrowed_mut(r) { - return false; + return None; } - self.mutable.push(r); - true + let h = self.new_handle(); + self.mutable.insert(h, r); + Some(h) } - pub fn unborrow_mut(&mut self, r: Region) { - let (ix, _) = self - .mutable - .iter() - .enumerate() - .find(|(_, reg)| r == **reg) - .expect("region exists in borrows"); - self.mutable.remove(ix); + pub fn unborrow_mut(&mut self, h: BorrowHandle) { + self.mutable + .remove(&h) + .expect("handle exists in mutable borrows"); } } diff --git a/crates/memory/src/guest_type.rs b/crates/memory/src/guest_type.rs index 9410e6054d..324ca5bf09 100644 --- a/crates/memory/src/guest_type.rs +++ b/crates/memory/src/guest_type.rs @@ -1,4 +1,4 @@ -use crate::{GuestPtr, GuestPtrMut, MemoryError}; +use crate::{GuestPtrMut, GuestPtrRead, MemoryError}; use thiserror::Error; pub trait GuestType: Sized { @@ -13,25 +13,25 @@ pub enum GuestValueError { } pub trait GuestTypeCopy: GuestType + Copy { - fn read_val(src: GuestPtr) -> Result; - fn write_val(val: Self, dest: GuestPtrMut); + fn read_val>(src: &P) -> Result; + fn write_val(val: Self, dest: &GuestPtrMut); } pub trait GuestTypeClone: GuestType + Clone { - fn read_ref(src: GuestPtr, dest: &mut Self) -> Result<(), GuestValueError>; - fn write_ref(val: &Self, dest: GuestPtrMut); + fn read_ref>(src: &P, dest: &mut Self) -> Result<(), GuestValueError>; + fn write_ref(val: &Self, dest: &GuestPtrMut); } impl GuestTypeClone for T where T: GuestTypeCopy, { - fn read_ref(src: GuestPtr, dest: &mut T) -> Result<(), GuestValueError> { + fn read_ref>(src: &P, dest: &mut T) -> Result<(), GuestValueError> { let val = GuestTypeCopy::read_val(src)?; *dest = val; Ok(()) } - fn write_ref(val: &T, dest: GuestPtrMut) { + fn write_ref(val: &T, dest: &GuestPtrMut) { GuestTypeCopy::write_val(*val, dest) } } @@ -49,12 +49,12 @@ macro_rules! builtin_copy { } impl GuestTypeCopy for $t { - fn read_val(src: GuestPtr<$t>) -> Result<$t, GuestValueError> { + fn read_val>(src: &P) -> Result<$t, GuestValueError> { Ok(unsafe { ::std::ptr::read_unaligned(src.ptr() as *const $t) }) } - fn write_val(val: $t, dest: GuestPtrMut<$t>) { + fn write_val(val: $t, dest: &GuestPtrMut<$t>) { unsafe { ::std::ptr::write_unaligned(dest.ptr_mut() as *mut $t, val) } diff --git a/crates/memory/src/lib.rs b/crates/memory/src/lib.rs index 872c556b15..1099776f95 100644 --- a/crates/memory/src/lib.rs +++ b/crates/memory/src/lib.rs @@ -4,5 +4,5 @@ mod memory; mod region; pub use guest_type::{GuestError, GuestType, GuestTypeClone, GuestTypeCopy, GuestValueError}; -pub use memory::{GuestMemory, GuestPtr, GuestPtrMut, MemoryError}; +pub use memory::{GuestMemory, GuestPtr, GuestPtrMut, GuestPtrRead, MemoryError}; pub use region::Region; diff --git a/crates/memory/src/memory.rs b/crates/memory/src/memory.rs index efb027550e..64fa192953 100644 --- a/crates/memory/src/memory.rs +++ b/crates/memory/src/memory.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use std::rc::Rc; use thiserror::Error; -use crate::borrow::GuestBorrows; +use crate::borrow::{BorrowHandle, GuestBorrows}; use crate::guest_type::GuestType; use crate::region::Region; @@ -31,80 +31,85 @@ impl<'a> GuestMemory<'a> { } pub fn ptr(&'a self, at: u32) -> Result, MemoryError> { - let r = Region { + let region = Region { start: at, len: T::size(), }; - let mut borrows = self.borrows.borrow_mut(); - if !self.contains(r) { - Err(MemoryError::OutOfBounds(r))?; + if !self.contains(region) { + Err(MemoryError::OutOfBounds(region))?; } - if borrows.borrow_immut(r) { + let mut borrows = self.borrows.borrow_mut(); + if let Some(handle) = borrows.borrow_immut(region) { Ok(GuestPtr { mem: &self, - region: r, + region, + handle, type_: PhantomData, }) } else { - Err(MemoryError::Borrowed(r)) + Err(MemoryError::Borrowed(region)) } } pub fn ptr_mut(&'a self, at: u32) -> Result, MemoryError> { - let r = Region { + let region = Region { start: at, len: T::size(), }; - let mut borrows = self.borrows.borrow_mut(); - if !self.contains(r) { - Err(MemoryError::OutOfBounds(r))?; + if !self.contains(region) { + Err(MemoryError::OutOfBounds(region))?; } - if borrows.borrow_mut(r) { + let mut borrows = self.borrows.borrow_mut(); + if let Some(handle) = borrows.borrow_mut(region) { Ok(GuestPtrMut { mem: &self, - region: r, + region, + handle, type_: PhantomData, }) } else { - Err(MemoryError::Borrowed(r)) + Err(MemoryError::Borrowed(region)) } } } +pub trait GuestPtrRead { + fn ptr(&self) -> *const u8; +} + pub struct GuestPtr<'a, T> { mem: &'a GuestMemory<'a>, region: Region, + handle: BorrowHandle, type_: PhantomData, } -impl<'a, T: GuestType> GuestPtr<'a, T> { - pub fn ptr(&self) -> *const u8 { +impl<'a, T: GuestType> GuestPtrRead for GuestPtr<'a, T> { + fn ptr(&self) -> *const u8 { (self.mem.ptr as usize + self.region.start as usize) as *const u8 } - - pub unsafe fn downcast(self) -> GuestPtr<'a, Q> { - debug_assert!(T::size() == Q::size(), "downcast to type of same size"); - GuestPtr { - mem: self.mem, - region: self.region, - type_: PhantomData, - } - } } impl<'a, T> Drop for GuestPtr<'a, T> { fn drop(&mut self) { let mut borrows = self.mem.borrows.borrow_mut(); - borrows.unborrow_immut(self.region); + borrows.unborrow_immut(self.handle); } } pub struct GuestPtrMut<'a, T> { mem: &'a GuestMemory<'a>, region: Region, + handle: BorrowHandle, type_: PhantomData, } +impl<'a, T: GuestType> GuestPtrRead for GuestPtrMut<'a, T> { + fn ptr(&self) -> *const u8 { + (self.mem.ptr as usize + self.region.start as usize) as *const u8 + } +} + impl<'a, T> GuestPtrMut<'a, T> { pub fn ptr_mut(&self) -> *mut u8 { (self.mem.ptr as usize + self.region.start as usize) as *mut u8 @@ -113,7 +118,7 @@ impl<'a, T> GuestPtrMut<'a, T> { impl<'a, T> Drop for GuestPtrMut<'a, T> { fn drop(&mut self) { let mut borrows = self.mem.borrows.borrow_mut(); - borrows.unborrow_mut(self.region); + borrows.unborrow_mut(self.handle); } } diff --git a/src/lib.rs b/src/lib.rs index f6500df8c0..2f8f782965 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,8 +12,29 @@ pub mod test { println!("BAR: {} {}", an_int, an_float); Ok(()) } - fn baz(&mut self, excuse: types::Excuse) -> Result<(), types::Errno> { - println!("BAZ: {:?}", excuse); + fn baz( + &mut self, + excuse: types::Excuse, + a_better_excuse_by_reference: ::memory::GuestPtrMut, + a_lamer_excuse_by_reference: ::memory::GuestPtr, + ) -> Result<(), types::Errno> { + use memory::GuestTypeCopy; + let a_better_excuse = + types::Excuse::read_val(&a_better_excuse_by_reference).map_err(|val_err| { + eprintln!("a_better_excuse_by_reference value error: {:?}", val_err); + types::Errno::InvalidArg + })?; + let a_lamer_excuse = + types::Excuse::read_val(&a_lamer_excuse_by_reference).map_err(|val_err| { + eprintln!("a_lamer_excuse_by_reference value error: {:?}", val_err); + types::Errno::InvalidArg + })?; + types::Excuse::write_val(a_lamer_excuse, &a_better_excuse_by_reference); + + println!( + "BAZ: {:?} {:?} {:?}", + excuse, a_better_excuse, a_lamer_excuse + ); Ok(()) } } diff --git a/test.witx b/test.witx index 9611b9203b..d8e8848cc5 100644 --- a/test.witx +++ b/test.witx @@ -21,5 +21,7 @@ (result $error $errno)) (@interface func (export "baz") (param $an_excuse $excuse) + (param $an_excuse_by_reference (@witx pointer $excuse)) + (param $a_lamer_excuse (@witx const_pointer $excuse)) (result $error $errno)) )