diff --git a/crates/generate/src/funcs.rs b/crates/generate/src/funcs.rs index de2c7413da..f624e15146 100644 --- a/crates/generate/src/funcs.rs +++ b/crates/generate/src/funcs.rs @@ -97,7 +97,7 @@ pub fn define_func(names: &Names, func: &witx::InterfaceFunc) -> TokenStream { Err(e) => { return #err_val; }, }; #(#marshal_rets)* - let success:#err_type = ::memory::GuestError::success(); + let success:#err_type = ::memory::GuestErrorType::success(); #abi_ret::from(success) }) } @@ -111,7 +111,7 @@ fn marshal_arg( let interface_typename = names.type_ref(&tref); let name = names.func_param(¶m.name); - let error_handling = |method| -> TokenStream { + let error_handling: TokenStream = { if let Some(tref) = error_type { let abi_ret = match tref.type_().passed_by() { witx::TypePassedBy::Value(atom) => names.atom_type(atom), @@ -119,7 +119,7 @@ fn marshal_arg( }; let err_typename = names.type_ref(&tref); quote! { - let err: #err_typename = ::memory::GuestError::#method(e, ctx); + let err: #err_typename = ::memory::GuestErrorType::from_error(e, ctx); return #abi_ret::from(err); } } else { @@ -128,15 +128,13 @@ fn marshal_arg( } } }; - let value_error_handling = error_handling(quote!(from_value_error)); - let memory_error_handling = error_handling(quote!(from_memory_error)); let try_into_conversion = quote! { use ::std::convert::TryInto; let #name: #interface_typename = match #name.try_into() { Ok(a) => a, Err(e) => { - #value_error_handling + #error_handling } }; }; @@ -151,7 +149,7 @@ fn marshal_arg( let #name: #interface_typename = match (#name as i32).try_into() { Ok(a) => a, Err(e) => { - #value_error_handling + #error_handling } } }, @@ -172,7 +170,7 @@ fn marshal_arg( let #name = match memory.ptr_mut::<#pointee_type>(#name as u32) { Ok(p) => p, Err(e) => { - #memory_error_handling + #error_handling } }; } @@ -183,7 +181,7 @@ fn marshal_arg( let #name = match memory.ptr::<#pointee_type>(#name as u32) { Ok(p) => p, Err(e) => { - #memory_error_handling + #error_handling } }; } diff --git a/crates/generate/src/types.rs b/crates/generate/src/types.rs index a03b4003a2..d385ae3715 100644 --- a/crates/generate/src/types.rs +++ b/crates/generate/src/types.rs @@ -55,18 +55,18 @@ fn define_enum(names: &Names, name: &witx::Id, e: &witx::EnumDatatype) -> TokenS } impl ::std::convert::TryFrom<#repr> for #ident { - type Error = ::memory::GuestValueError; - fn try_from(value: #repr) -> Result<#ident, ::memory::GuestValueError> { + type Error = ::memory::GuestError; + fn try_from(value: #repr) -> Result<#ident, ::memory::GuestError> { match value as usize { #(#tryfrom_repr_cases),*, - _ => Err(::memory::GuestValueError::InvalidEnum(stringify!(#ident))), + _ => Err(::memory::GuestError::InvalidEnumValue(stringify!(#ident))), } } } impl ::std::convert::TryFrom<#abi_repr> for #ident { - type Error = ::memory::GuestValueError; - fn try_from(value: #abi_repr) -> Result<#ident, ::memory::GuestValueError> { + type Error = ::memory::GuestError; + fn try_from(value: #abi_repr) -> Result<#ident, ::memory::GuestError> { #ident::try_from(value as #repr) } } @@ -95,7 +95,7 @@ fn define_enum(names: &Names, name: &witx::Id, e: &witx::EnumDatatype) -> TokenS } impl ::memory::GuestTypeCopy for #ident { - fn read_val>(src: &P) -> Result<#ident, ::memory::GuestValueError> { + fn read_val<'a, P: ::memory::GuestPtrRead<'a, #ident>>(src: &P) -> Result<#ident, ::memory::GuestError> { use ::std::convert::TryInto; let val = unsafe { ::std::ptr::read_unaligned(src.ptr() as *const #repr) }; val.try_into() diff --git a/crates/memory/src/error.rs b/crates/memory/src/error.rs new file mode 100644 index 0000000000..7c5632ca05 --- /dev/null +++ b/crates/memory/src/error.rs @@ -0,0 +1,12 @@ +use crate::Region; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum GuestError { + #[error("Invalid enum value {0}")] + InvalidEnumValue(&'static str), + #[error("Out of bounds: {0:?}")] + PtrOutOfBounds(Region), + #[error("Borrowed: {0:?}")] + PtrBorrowed(Region), +} diff --git a/crates/memory/src/guest_type.rs b/crates/memory/src/guest_type.rs index 324ca5bf09..6233888540 100644 --- a/crates/memory/src/guest_type.rs +++ b/crates/memory/src/guest_type.rs @@ -1,24 +1,17 @@ -use crate::{GuestPtrMut, GuestPtrRead, MemoryError}; -use thiserror::Error; +use crate::{GuestError, GuestPtrMut, GuestPtrRead}; pub trait GuestType: Sized { fn size() -> u32; fn name() -> &'static str; } -#[derive(Debug, Error)] -pub enum GuestValueError { - #[error("Invalid enum {0}")] - InvalidEnum(&'static str), -} - pub trait GuestTypeCopy: GuestType + Copy { - fn read_val>(src: &P) -> Result; + fn read_val<'a, P: GuestPtrRead<'a, Self>>(src: &P) -> Result; fn write_val(val: Self, dest: &GuestPtrMut); } pub trait GuestTypeClone: GuestType + Clone { - fn read_ref>(src: &P, dest: &mut Self) -> Result<(), GuestValueError>; + fn read_ref<'a, P: GuestPtrRead<'a, Self>>(src: &P, dest: &mut Self) -> Result<(), GuestError>; fn write_ref(val: &Self, dest: &GuestPtrMut); } @@ -26,7 +19,7 @@ impl GuestTypeClone for T where T: GuestTypeCopy, { - fn read_ref>(src: &P, dest: &mut T) -> Result<(), GuestValueError> { + fn read_ref<'a, P: GuestPtrRead<'a, Self>>(src: &P, dest: &mut T) -> Result<(), GuestError> { let val = GuestTypeCopy::read_val(src)?; *dest = val; Ok(()) @@ -49,7 +42,7 @@ macro_rules! builtin_copy { } impl GuestTypeCopy for $t { - fn read_val>(src: &P) -> Result<$t, GuestValueError> { + fn read_val<'a, P: GuestPtrRead<'a, $t>>(src: &P) -> Result<$t, GuestError> { Ok(unsafe { ::std::ptr::read_unaligned(src.ptr() as *const $t) }) @@ -67,9 +60,8 @@ macro_rules! builtin_copy { // These definitions correspond to all the witx BuiltinType variants that are Copy: builtin_copy!(u8, i8, u16, i16, u32, i32, u64, i64, f32, f64, usize, char); -pub trait GuestError { +pub trait GuestErrorType { type Context; fn success() -> Self; - fn from_memory_error(memory_error: MemoryError, ctx: &mut Self::Context) -> Self; - fn from_value_error(value_error: GuestValueError, ctx: &mut Self::Context) -> Self; + fn from_error(e: GuestError, ctx: &mut Self::Context) -> Self; } diff --git a/crates/memory/src/lib.rs b/crates/memory/src/lib.rs index 1099776f95..f8705d6c80 100644 --- a/crates/memory/src/lib.rs +++ b/crates/memory/src/lib.rs @@ -1,8 +1,10 @@ mod borrow; +mod error; mod guest_type; mod memory; mod region; -pub use guest_type::{GuestError, GuestType, GuestTypeClone, GuestTypeCopy, GuestValueError}; -pub use memory::{GuestMemory, GuestPtr, GuestPtrMut, GuestPtrRead, MemoryError}; +pub use error::GuestError; +pub use guest_type::{GuestErrorType, GuestType, GuestTypeClone, GuestTypeCopy}; +pub use memory::{GuestMemory, GuestPtr, GuestPtrMut, GuestPtrRead}; pub use region::Region; diff --git a/crates/memory/src/memory.rs b/crates/memory/src/memory.rs index 64fa192953..3af1a0413d 100644 --- a/crates/memory/src/memory.rs +++ b/crates/memory/src/memory.rs @@ -1,11 +1,9 @@ use std::cell::RefCell; use std::marker::PhantomData; use std::rc::Rc; -use thiserror::Error; use crate::borrow::{BorrowHandle, GuestBorrows}; -use crate::guest_type::GuestType; -use crate::region::Region; +use crate::{GuestError, GuestType, Region}; pub struct GuestMemory<'a> { ptr: *mut u8, @@ -30,13 +28,13 @@ impl<'a> GuestMemory<'a> { && r.start < (self.len - r.len) } - pub fn ptr(&'a self, at: u32) -> Result, MemoryError> { + pub fn ptr(&'a self, at: u32) -> Result, GuestError> { let region = Region { start: at, len: T::size(), }; if !self.contains(region) { - Err(MemoryError::OutOfBounds(region))?; + Err(GuestError::PtrOutOfBounds(region))?; } let mut borrows = self.borrows.borrow_mut(); if let Some(handle) = borrows.borrow_immut(region) { @@ -47,17 +45,17 @@ impl<'a> GuestMemory<'a> { type_: PhantomData, }) } else { - Err(MemoryError::Borrowed(region)) + Err(GuestError::PtrBorrowed(region)) } } - pub fn ptr_mut(&'a self, at: u32) -> Result, MemoryError> { + pub fn ptr_mut(&'a self, at: u32) -> Result, GuestError> { let region = Region { start: at, len: T::size(), }; if !self.contains(region) { - Err(MemoryError::OutOfBounds(region))?; + Err(GuestError::PtrOutOfBounds(region))?; } let mut borrows = self.borrows.borrow_mut(); if let Some(handle) = borrows.borrow_mut(region) { @@ -68,13 +66,19 @@ impl<'a> GuestMemory<'a> { type_: PhantomData, }) } else { - Err(MemoryError::Borrowed(region)) + Err(GuestError::PtrBorrowed(region)) } } } -pub trait GuestPtrRead { - fn ptr(&self) -> *const u8; +/// These methods should not be used by the end user - just by implementations of the +/// GuestValueClone and GuestValueCopy traits! +pub trait GuestPtrRead<'a, T> { + fn mem(&self) -> &'a GuestMemory<'a>; + fn region(&self) -> &Region; + fn ptr(&self) -> *const u8 { + (self.mem().ptr as usize + self.region().start as usize) as *const u8 + } } pub struct GuestPtr<'a, T> { @@ -84,9 +88,31 @@ pub struct GuestPtr<'a, T> { type_: PhantomData, } -impl<'a, T: GuestType> GuestPtrRead for GuestPtr<'a, T> { - fn ptr(&self) -> *const u8 { - (self.mem.ptr as usize + self.region.start as usize) as *const u8 +impl<'a, T> GuestPtrRead<'a, T> for GuestPtr<'a, T> { + fn mem(&self) -> &'a GuestMemory<'a> { + self.mem + } + fn region(&self) -> &Region { + &self.region + } +} + +impl<'a, T> GuestType for GuestPtr<'a, T> { + fn size() -> u32 { + 4 + } + fn name() -> &'static str { + "GuestPtr<...>" + } +} + +impl<'a, T: GuestType> GuestPtr<'a, T> { + pub fn read_ptr>(src: &P) -> Result { + let raw_ptr = unsafe { ::std::ptr::read_unaligned(src.ptr() as *const u32) }; + src.mem().ptr(raw_ptr) + } + pub fn write_ptr(ptr: &Self, dest: &GuestPtrMut) { + unsafe { ::std::ptr::write_unaligned(dest.ptr_mut() as *mut u32, ptr.region.start) } } } @@ -104,9 +130,12 @@ pub struct GuestPtrMut<'a, T> { type_: PhantomData, } -impl<'a, T: GuestType> GuestPtrRead for GuestPtrMut<'a, T> { - fn ptr(&self) -> *const u8 { - (self.mem.ptr as usize + self.region.start as usize) as *const u8 +impl<'a, T> GuestPtrRead<'a, T> for GuestPtrMut<'a, T> { + fn mem(&self) -> &'a GuestMemory<'a> { + self.mem + } + fn region(&self) -> &Region { + &self.region } } @@ -115,6 +144,7 @@ impl<'a, T> GuestPtrMut<'a, T> { (self.mem.ptr as usize + self.region.start as usize) as *mut u8 } } + impl<'a, T> Drop for GuestPtrMut<'a, T> { fn drop(&mut self) { let mut borrows = self.mem.borrows.borrow_mut(); @@ -122,10 +152,29 @@ impl<'a, T> Drop for GuestPtrMut<'a, T> { } } -#[derive(Debug, Error)] -pub enum MemoryError { - #[error("Out of bounds: {0:?}")] - OutOfBounds(Region), - #[error("Borrowed: {0:?}")] - Borrowed(Region), +impl<'a, T> GuestType for GuestPtrMut<'a, T> { + fn size() -> u32 { + 4 + } + fn name() -> &'static str { + "GuestPtrMut<...>" + } +} + +impl<'a, T: GuestType> GuestPtrMut<'a, T> { + pub fn read_ptr>(src: &P) -> Result { + let raw_ptr = unsafe { ::std::ptr::read_unaligned(src.ptr() as *const u32) }; + src.mem().ptr_mut(raw_ptr) + } + pub fn write_ptr(ptr: &Self, dest: &GuestPtrMut) { + unsafe { ::std::ptr::write_unaligned(dest.ptr_mut() as *mut u32, ptr.region.start) } + } + + pub fn as_immut(self) -> GuestPtr<'a, T> { + let mem = self.mem; + let start = self.region.start; + drop(self); + mem.ptr(start) + .expect("can borrow just-dropped mutable region as immut") + } } diff --git a/src/lib.rs b/src/lib.rs index 2f8f782965..fd287a0f0c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,8 +3,7 @@ pub mod test { generate::from_witx!("test.witx"); pub struct WasiCtx { - mem_errors: Vec<::memory::MemoryError>, - value_errors: Vec<::memory::GuestValueError>, + guest_errors: Vec<::memory::GuestError>, } impl foo::Foo for WasiCtx { @@ -17,43 +16,65 @@ pub mod test { excuse: types::Excuse, a_better_excuse_by_reference: ::memory::GuestPtrMut, a_lamer_excuse_by_reference: ::memory::GuestPtr, + two_layers_of_excuses: ::memory::GuestPtrMut<::memory::GuestPtr>, ) -> Result<(), types::Errno> { use memory::GuestTypeCopy; + + // Read enum value from mutable: let a_better_excuse = - types::Excuse::read_val(&a_better_excuse_by_reference).map_err(|val_err| { - eprintln!("a_better_excuse_by_reference value error: {:?}", val_err); + types::Excuse::read_val(&a_better_excuse_by_reference).map_err(|e| { + eprintln!("a_better_excuse_by_reference error: {}", e); types::Errno::InvalidArg })?; + + // Read enum value from immutable ptr: let a_lamer_excuse = - types::Excuse::read_val(&a_lamer_excuse_by_reference).map_err(|val_err| { - eprintln!("a_lamer_excuse_by_reference value error: {:?}", val_err); + types::Excuse::read_val(&a_lamer_excuse_by_reference).map_err(|e| { + eprintln!("a_lamer_excuse_by_reference error: {}", e); types::Errno::InvalidArg })?; + + // Write enum to mutable ptr: types::Excuse::write_val(a_lamer_excuse, &a_better_excuse_by_reference); + // Read ptr value from mutable ptr: + let one_layer_down = + ::memory::GuestPtr::read_ptr(&two_layers_of_excuses).map_err(|e| { + eprintln!("one_layer_down error: {}", e); + types::Errno::InvalidArg + })?; + + // Read enum value from that ptr: + let two_layers_down = types::Excuse::read_val(&one_layer_down).map_err(|e| { + eprintln!("two_layers_down error: {}", e); + types::Errno::InvalidArg + })?; + + // Write ptr value to mutable ptr: + ::memory::GuestPtr::write_ptr( + &a_better_excuse_by_reference.as_immut(), + &two_layers_of_excuses, + ); + println!( - "BAZ: {:?} {:?} {:?}", - excuse, a_better_excuse, a_lamer_excuse + "BAZ: excuse: {:?}, better excuse: {:?}, lamer excuse: {:?}, two layers down: {:?}", + excuse, a_better_excuse, a_lamer_excuse, two_layers_down ); Ok(()) } } // Errno is used as a first return value in the functions above, therefore - // it must implement GuestError with type Context = WasiCtx. + // it must implement GuestErrorType with type Context = WasiCtx. // The context type should let you do logging or debugging or whatever you need // with these errors. We just push them to vecs. - impl ::memory::GuestError for types::Errno { + impl ::memory::GuestErrorType for types::Errno { type Context = WasiCtx; fn success() -> types::Errno { types::Errno::Ok } - fn from_memory_error(e: ::memory::MemoryError, ctx: &mut WasiCtx) -> types::Errno { - ctx.mem_errors.push(e); - types::Errno::InvalidArg - } - fn from_value_error(e: ::memory::GuestValueError, ctx: &mut WasiCtx) -> types::Errno { - ctx.value_errors.push(e); + fn from_error(e: ::memory::GuestError, ctx: &mut WasiCtx) -> types::Errno { + ctx.guest_errors.push(e); types::Errno::InvalidArg } } diff --git a/test.witx b/test.witx index d8e8848cc5..16964d8eea 100644 --- a/test.witx +++ b/test.witx @@ -23,5 +23,6 @@ (param $an_excuse $excuse) (param $an_excuse_by_reference (@witx pointer $excuse)) (param $a_lamer_excuse (@witx const_pointer $excuse)) + (param $two_layers_of_excuses (@witx pointer (@witx const_pointer $excuse))) (result $error $errno)) )