diff --git a/crates/generate/Cargo.toml b/crates/generate/Cargo.toml index cd720d549c..2b89101bd3 100644 --- a/crates/generate/Cargo.toml +++ b/crates/generate/Cargo.toml @@ -8,6 +8,7 @@ edition = "2018" proc-macro = true [dependencies] +memory = { path = "../memory" } witx = { path = "../WASI/tools/witx" } quote = "1.0" proc-macro2 = "1.0" diff --git a/crates/generate/src/imp.rs b/crates/generate/src/imp.rs deleted file mode 100644 index bd980c139b..0000000000 --- a/crates/generate/src/imp.rs +++ /dev/null @@ -1,266 +0,0 @@ -use heck::{CamelCase, MixedCase, ShoutySnakeCase}; -use proc_macro2::{Delimiter, Group, Literal, TokenStream, TokenTree}; -use quote::{format_ident, quote}; -use std::convert::TryFrom; - -#[derive(Copy, Clone, Debug, Eq, PartialEq)] -pub enum Mode { - Host, - Wasi32, - Wasi, -} - -impl Mode { - pub fn include_target_types(&self) -> bool { - match self { - Mode::Host | Mode::Wasi32 => true, - Mode::Wasi => false, - } - } -} - -pub fn gen(doc: witx::Document) -> TokenStream { - let mut output = TokenStream::new(); - gen_datatypes(&mut output, &doc, Mode::Wasi); - // gen_datatypes(&mut output, &doc, Mode::Wasi32); - // gen_datatypes(&mut output, &doc, Mode::Host); - - output -} - -fn gen_datatypes(output: &mut TokenStream, doc: &witx::Document, mode: Mode) { - for namedtype in doc.typenames() { - if mode.include_target_types() != namedtype_has_target_size(&namedtype) { - continue; - } - - gen_datatype(output, mode, &namedtype); - } -} - -fn gen_datatype(output: &mut TokenStream, mode: Mode, namedtype: &witx::NamedType) { - let wasi_name = format_ident!("{}", namedtype.name.as_str().to_camel_case()); - match &namedtype.tref { - witx::TypeRef::Name(alias_to) => { - let to = tref_tokens(mode, &alias_to.tref); - output.extend(quote!(pub type #wasi_name = #to;)); - } - witx::TypeRef::Value(v) => match &**v { - witx::Type::Enum(e) => { - let repr = int_repr_tokens(e.repr); - output.extend(quote!(#[repr(#repr)])); - output - .extend(quote!(#[derive(Copy, Clone, Debug, std::hash::Hash, Eq, PartialEq)])); - - let mut inner = TokenStream::new(); - for variant in &e.variants { - let value_name = if namedtype.name.as_str() == "errno" { - // FIXME discussion point! - format_ident!("E{}", variant.name.as_str().to_mixed_case()) - } else { - format_ident!("{}", variant.name.as_str().to_camel_case()) - }; - inner.extend(quote!(#value_name,)); - } - - output.extend(quote!(pub enum #wasi_name { - #inner - })); - } - witx::Type::Int(_) => {} // TODO - witx::Type::Flags(f) => { - let repr = int_repr_tokens(f.repr); - output.extend(quote!(#[repr(transparent)])); - output - .extend(quote!(#[derive(Copy, Clone, Debug, std::hash::Hash, Eq, PartialEq)])); - output.extend(quote!(pub struct #wasi_name(#repr);)); - // TODO - // Since `Flags` are represented by a "transparent" struct, we should probably - // auto-generate `from_raw(raw: #repr)` method or similar - - let mut inner = TokenStream::new(); - for (index, flag) in f.flags.iter().enumerate() { - let value_name = format_ident!("{}", flag.name.as_str().to_shouty_snake_case()); - let flag_value = Literal::u128_unsuffixed( - 1u128 - .checked_shl(u32::try_from(index).expect("flag value overflow")) - .expect("flag value overflow"), - ); - inner.extend( - quote!(pub const #value_name: #wasi_name = #wasi_name(#flag_value);), - ); - } - - output.extend(quote!(impl #wasi_name { - #inner - })); - } - witx::Type::Struct(s) => { - output.extend(quote!(#[repr(C)])); - // Types which contain unions can't trivially implement Debug, - // Hash, or Eq, because the type itself doesn't record which - // union member is active. - if struct_has_union(&s) { - output.extend(quote!(#[derive(Copy, Clone)])); - output.extend(quote!(#[allow(missing_debug_implementations)])); - } else { - output.extend(quote!(#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)])); - } - - output.extend(quote!(pub struct #wasi_name)); - - let mut inner = TokenStream::new(); - for member in &s.members { - let member_name = format_ident!("r#{}", member.name.as_str()); - let member_type = tref_tokens(mode, &member.tref); - inner.extend(quote!(pub #member_name: #member_type,)); - } - let braced = Group::new(Delimiter::Brace, inner); - output.extend(TokenStream::from(TokenTree::Group(braced))); - } - witx::Type::Union(u) => { - output.extend(quote!(#[repr(C)])); - output.extend(quote!(#[derive(Copy, Clone)])); - output.extend(quote!(#[allow(missing_debug_implementations)])); - - output.extend(quote!(pub union #wasi_name)); - - let mut inner = TokenStream::new(); - for variant in &u.variants { - let variant_name = format_ident!("r#{}", variant.name.as_str()); - let variant_type = tref_tokens(mode, &variant.tref); - inner.extend(quote!(pub #variant_name: #variant_type,)); - } - let braced = Group::new(Delimiter::Brace, inner); - output.extend(TokenStream::from(TokenTree::Group(braced))); - } - witx::Type::Handle(_h) => { - output.extend(quote!(pub type #wasi_name = u32;)); - } - witx::Type::Builtin(b) => { - if namedtype.name.as_str() == "size" { - match mode { - Mode::Host => output.extend(quote!(pub type #wasi_name = usize;)), - Mode::Wasi => panic!("size has target-specific size"), - Mode::Wasi32 => output.extend(quote!(pub type #wasi_name = u32;)), - } - } else { - let b_type = builtin_tokens(mode, *b); - output.extend(quote!(pub type #wasi_name = #b_type;)); - } - } - witx::Type::Pointer { .. } - | witx::Type::ConstPointer { .. } - | witx::Type::Array { .. } => { - let tref_tokens = tref_tokens(mode, &namedtype.tref); - output.extend(quote!(pub type #wasi_name = #tref_tokens;)); - } - }, - } -} - -fn int_repr_tokens(int_repr: witx::IntRepr) -> TokenStream { - match int_repr { - witx::IntRepr::U8 => quote!(u8), - witx::IntRepr::U16 => quote!(u16), - witx::IntRepr::U32 => quote!(u32), - witx::IntRepr::U64 => quote!(u64), - } -} - -fn builtin_tokens(mode: Mode, builtin: witx::BuiltinType) -> TokenStream { - match builtin { - witx::BuiltinType::String => match mode { - Mode::Host => quote!((*const u8, usize)), - Mode::Wasi => panic!("strings have target-specific size"), - Mode::Wasi32 => quote!((u32, u32)), - }, - witx::BuiltinType::U8 => quote!(u8), - witx::BuiltinType::U16 => quote!(u16), - witx::BuiltinType::U32 => quote!(u32), - witx::BuiltinType::U64 => quote!(u64), - witx::BuiltinType::S8 => quote!(i8), - witx::BuiltinType::S16 => quote!(i16), - witx::BuiltinType::S32 => quote!(i32), - witx::BuiltinType::S64 => quote!(i64), - witx::BuiltinType::F32 => quote!(f32), - witx::BuiltinType::F64 => quote!(f64), - witx::BuiltinType::Char8 => quote!(char), - witx::BuiltinType::USize => quote!(usize), - } -} - -fn tref_tokens(mode: Mode, tref: &witx::TypeRef) -> TokenStream { - match tref { - witx::TypeRef::Name(n) => TokenStream::from(TokenTree::Ident(format_ident!( - "{}", - n.name.as_str().to_camel_case() - ))), - witx::TypeRef::Value(v) => match &**v { - witx::Type::Builtin(b) => builtin_tokens(mode, *b), - witx::Type::Pointer(pointee) => { - let pointee = tref_tokens(mode, pointee); - match mode { - Mode::Host => quote!(*mut #pointee), - Mode::Wasi => panic!("pointers have target-specific size"), - Mode::Wasi32 => quote!(u32), - } - } - witx::Type::ConstPointer(pointee) => { - let pointee = tref_tokens(mode, pointee); - match mode { - Mode::Host => quote!(*const #pointee), - Mode::Wasi => panic!("pointers have target-specific size"), - Mode::Wasi32 => quote!(u32), - } - } - witx::Type::Array(element) => { - let element_name = tref_tokens(mode, element); - match mode { - Mode::Host => quote!((*const #element_name, usize)), - Mode::Wasi => panic!("arrays have target-specific size"), - Mode::Wasi32 => quote!((u32, u32)), - } - } - t => panic!("cannot give name to anonymous type {:?}", t), - }, - } -} - -/// Test whether the given struct contains any union members. -fn struct_has_union(s: &witx::StructDatatype) -> bool { - s.members.iter().any(|member| match &*member.tref.type_() { - witx::Type::Union { .. } => true, - witx::Type::Struct(s) => struct_has_union(&s), - _ => false, - }) -} - -/// Test whether the type referred to has a target-specific size. -fn tref_has_target_size(tref: &witx::TypeRef) -> bool { - match tref { - witx::TypeRef::Name(nt) => namedtype_has_target_size(&nt), - witx::TypeRef::Value(t) => type_has_target_size(&t), - } -} - -/// Test whether the given named type has a target-specific size. -fn namedtype_has_target_size(nt: &witx::NamedType) -> bool { - if nt.name.as_str() == "size" { - true - } else { - tref_has_target_size(&nt.tref) - } -} - -/// Test whether the given type has a target-specific size. -fn type_has_target_size(ty: &witx::Type) -> bool { - match ty { - witx::Type::Builtin(witx::BuiltinType::String) => true, - witx::Type::Pointer { .. } | witx::Type::ConstPointer { .. } => true, - witx::Type::Array(elem) => tref_has_target_size(elem), - witx::Type::Struct(s) => s.members.iter().any(|m| tref_has_target_size(&m.tref)), - witx::Type::Union(u) => u.variants.iter().any(|v| tref_has_target_size(&v.tref)), - _ => false, - } -} diff --git a/crates/generate/src/lib.rs b/crates/generate/src/lib.rs index f3b9a93f5f..f78a484153 100644 --- a/crates/generate/src/lib.rs +++ b/crates/generate/src/lib.rs @@ -1,16 +1,36 @@ extern crate proc_macro; -mod imp; mod parse; +mod types; +use heck::SnakeCase; use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; +use quote::{format_ident, quote}; +use types::define_datatype; #[proc_macro] pub fn from_witx(args: TokenStream) -> TokenStream { let args = TokenStream2::from(args); let witx_paths = parse::witx_paths(args).expect("parsing macro arguments"); let doc = witx::load(&witx_paths).expect("loading witx"); - let out = imp::gen(doc); - TokenStream::from(out) + + let mut types = TokenStream2::new(); + for namedtype in doc.typenames() { + let def = define_datatype(&namedtype); + types.extend(def); + } + + let mut modules = TokenStream2::new(); + for module in doc.modules() { + let modname = format_ident!("{}", module.name.as_str().to_snake_case()); + let mut fs = TokenStream2::new(); + for func in module.funcs() { + let ident = format_ident!("{}", func.name.as_str().to_snake_case()); + fs.extend(quote!(pub fn #ident() { unimplemented!() })); + } + modules.extend(quote!(mod #modname { use super::types::*; #fs })); + } + + TokenStream::from(quote!(mod types { #types } #modules)) } diff --git a/crates/generate/src/types.rs b/crates/generate/src/types.rs new file mode 100644 index 0000000000..f9c98b9234 --- /dev/null +++ b/crates/generate/src/types.rs @@ -0,0 +1,82 @@ +use heck::{CamelCase, MixedCase}; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; + +pub fn define_datatype(namedtype: &witx::NamedType) -> TokenStream { + match &namedtype.tref { + witx::TypeRef::Name(alias_to) => define_alias(&namedtype.name, &alias_to), + witx::TypeRef::Value(v) => match &**v { + witx::Type::Enum(e) => define_enum(&namedtype.name, &e), + witx::Type::Int(_) => unimplemented!("int types"), + witx::Type::Flags(_) => unimplemented!("flag types"), + witx::Type::Struct(_) => unimplemented!("struct types"), + witx::Type::Union(_) => unimplemented!("union types"), + witx::Type::Handle(_h) => unimplemented!("handle types"), + witx::Type::Builtin(b) => define_builtin(&namedtype.name, &b), + witx::Type::Pointer { .. } => unimplemented!("pointer types"), + witx::Type::ConstPointer { .. } => unimplemented!("constpointer types"), + witx::Type::Array { .. } => unimplemented!("array types"), + }, + } +} + +fn define_alias(name: &witx::Id, to: &witx::NamedType) -> TokenStream { + let ident = format_ident!("{}", name.as_str().to_camel_case()); + let to = format_ident!("{}", to.name.as_str().to_camel_case()); + + quote!(pub type #ident = #to;) +} + +fn define_enum(name: &witx::Id, e: &witx::EnumDatatype) -> TokenStream { + let ident = format_ident!("{}", name.as_str().to_camel_case()); + let mut output = TokenStream::new(); + let repr = int_repr_tokens(e.repr); + output.extend(quote!(#[repr(#repr)])); + output.extend(quote!(#[derive(Copy, Clone, Debug, std::hash::Hash, Eq, PartialEq)])); + + let mut inner = TokenStream::new(); + for variant in &e.variants { + let value_name = if name.as_str() == "errno" { + // FIXME discussion point! + format_ident!("E{}", variant.name.as_str().to_mixed_case()) + } else { + format_ident!("{}", variant.name.as_str().to_camel_case()) + }; + inner.extend(quote!(#value_name,)); + } + + output.extend(quote!(pub enum #ident { + #inner + })); + + output +} + +fn define_builtin(name: &witx::Id, builtin: &witx::BuiltinType) -> TokenStream { + let ident = format_ident!("{}", name.as_str().to_camel_case()); + let prim = match builtin { + witx::BuiltinType::String => quote!(String), + witx::BuiltinType::U8 => quote!(u8), + witx::BuiltinType::U16 => quote!(u16), + witx::BuiltinType::U32 => quote!(u32), + witx::BuiltinType::U64 => quote!(u64), + witx::BuiltinType::S8 => quote!(i8), + witx::BuiltinType::S16 => quote!(i16), + witx::BuiltinType::S32 => quote!(i32), + witx::BuiltinType::S64 => quote!(i64), + witx::BuiltinType::F32 => quote!(f32), + witx::BuiltinType::F64 => quote!(f64), + witx::BuiltinType::Char8 => quote!(char), + witx::BuiltinType::USize => quote!(usize), + }; + quote!(pub type #ident = #prim;) +} + +fn int_repr_tokens(int_repr: witx::IntRepr) -> TokenStream { + match int_repr { + witx::IntRepr::U8 => quote!(u8), + witx::IntRepr::U16 => quote!(u16), + witx::IntRepr::U32 => quote!(u32), + witx::IntRepr::U64 => quote!(u64), + } +} diff --git a/crates/memory/src/borrow.rs b/crates/memory/src/borrow.rs new file mode 100644 index 0000000000..50d5656f53 --- /dev/null +++ b/crates/memory/src/borrow.rs @@ -0,0 +1,59 @@ +use crate::region::Region; + +pub struct GuestBorrows { + immutable: Vec, + mutable: Vec, +} + +impl GuestBorrows { + pub fn new() -> Self { + GuestBorrows { + immutable: Vec::new(), + mutable: Vec::new(), + } + } + + fn is_borrowed_immut(&self, r: Region) -> bool { + !self.immutable.iter().all(|b| !b.overlaps(r)) + } + + fn is_borrowed_mut(&self, r: Region) -> bool { + !self.mutable.iter().all(|b| !b.overlaps(r)) + } + + pub fn borrow_immut(&mut self, r: Region) -> bool { + if self.is_borrowed_mut(r) { + return false; + } + self.immutable.push(r); + true + } + + pub fn unborrow_immut(&mut self, r: Region) { + let (ix, _) = self + .immutable + .iter() + .enumerate() + .find(|(_, reg)| r == **reg) + .expect("region exists in borrows"); + self.immutable.remove(ix); + } + + pub fn borrow_mut(&mut self, r: Region) -> bool { + if self.is_borrowed_immut(r) || self.is_borrowed_mut(r) { + return false; + } + self.mutable.push(r); + true + } + + pub fn unborrow_mut(&mut self, r: Region) { + let (ix, _) = self + .mutable + .iter() + .enumerate() + .find(|(_, reg)| r == **reg) + .expect("region exists in borrows"); + self.mutable.remove(ix); + } +} diff --git a/crates/memory/src/guest_type.rs b/crates/memory/src/guest_type.rs new file mode 100644 index 0000000000..c34aa0fce4 --- /dev/null +++ b/crates/memory/src/guest_type.rs @@ -0,0 +1,75 @@ +pub trait GuestType { + fn len() -> u32; +} + +impl GuestType for u8 { + fn len() -> u32 { + 1 + } +} + +impl GuestType for i8 { + fn len() -> u32 { + 1 + } +} + +impl GuestType for u16 { + fn len() -> u32 { + 2 + } +} + +impl GuestType for i16 { + fn len() -> u32 { + 2 + } +} + +impl GuestType for u32 { + fn len() -> u32 { + 4 + } +} + +impl GuestType for i32 { + fn len() -> u32 { + 4 + } +} + +impl GuestType for f32 { + fn len() -> u32 { + 4 + } +} + +impl GuestType for u64 { + fn len() -> u32 { + 8 + } +} + +impl GuestType for i64 { + fn len() -> u32 { + 8 + } +} + +impl GuestType for f64 { + fn len() -> u32 { + 8 + } +} + +impl GuestType for char { + fn len() -> u32 { + 1 + } +} + +impl GuestType for usize { + fn len() -> u32 { + 4 + } +} diff --git a/crates/memory/src/lib.rs b/crates/memory/src/lib.rs index aa768b498c..b8ade8806c 100644 --- a/crates/memory/src/lib.rs +++ b/crates/memory/src/lib.rs @@ -1,188 +1,8 @@ -#![allow(dead_code, unused)] // DURING DEVELOPMENT +mod borrow; +mod guest_type; +mod memory; +mod region; -use std::cell::RefCell; -use std::marker::PhantomData; -use std::rc::Rc; -use thiserror::Error; - -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub struct Region { - start: u32, - len: u32, -} - -impl Region { - fn overlaps(&self, rhs: Region) -> bool { - let self_start = self.start as u64; - let self_end = self.start as u64 + self.len as u64; - - let rhs_start = rhs.start as u64; - let rhs_end = rhs.start as u64 + rhs.len as u64; - - // start of rhs inside self: - if (rhs_start >= self_start && rhs_start < self_end) { - return true; - } - - // end of rhs inside self: - if (rhs_end >= self_start && rhs_end < self_end) { - return true; - } - - // start of self inside rhs: - if (self_start >= rhs_start && self_start < rhs_end) { - return true; - } - - // end of self inside rhs: XXX is this redundant? i suspect it is but im too tired - if (self_end >= rhs_start && self_end < rhs_end) { - return true; - } - - false - } -} - -struct GuestBorrows { - immutable: Vec, - mutable: Vec, -} - -impl GuestBorrows { - pub fn new() -> Self { - GuestBorrows { - immutable: Vec::new(), - mutable: Vec::new(), - } - } - - fn is_borrowed_immut(&self, r: Region) -> bool { - !self.immutable.iter().all(|b| !b.overlaps(r)) - } - - fn is_borrowed_mut(&self, r: Region) -> bool { - !self.mutable.iter().all(|b| !b.overlaps(r)) - } - - pub fn borrow_immut(&mut self, r: Region) -> bool { - if self.is_borrowed_mut(r) { - return false; - } - self.immutable.push(r); - true - } - - pub fn unborrow_immut(&mut self, r: Region) { - let (ix, _) = self - .immutable - .iter() - .enumerate() - .find(|(_, reg)| r == **reg) - .expect("region exists in borrows"); - self.immutable.remove(ix); - } - - pub fn borrow_mut(&mut self, r: Region) -> bool { - if self.is_borrowed_immut(r) || self.is_borrowed_mut(r) { - return false; - } - self.mutable.push(r); - true - } - - pub fn unborrow_mut(&mut self, r: Region) { - let (ix, _) = self - .mutable - .iter() - .enumerate() - .find(|(_, reg)| r == **reg) - .expect("region exists in borrows"); - self.mutable.remove(ix); - } -} - -pub struct GuestMemory<'a> { - ptr: *mut u8, - len: u32, - lifetime: PhantomData<&'a ()>, - borrows: Rc>, -} - -impl<'a> GuestMemory<'a> { - pub fn new(ptr: *mut u8, len: u32) -> GuestMemory<'a> { - GuestMemory { - ptr, - len, - lifetime: PhantomData, - borrows: Rc::new(RefCell::new(GuestBorrows::new())), - } - } - - fn contains(&self, r: Region) -> bool { - r.start < self.len - && r.len < self.len // make sure next clause doesnt underflow - && r.start < (self.len - r.len) - } - - pub fn ptr(&'a self, r: Region) -> Result>, MemoryError> { - let mut borrows = self.borrows.borrow_mut(); - if !self.contains(r) { - Err(MemoryError::OutOfBounds)?; - } - if borrows.borrow_immut(r) { - Ok(Some(GuestPtr { - mem: &self, - region: r, - })) - } else { - Ok(None) - } - } - - pub fn ptr_mut(&'a self, r: Region) -> Result>, MemoryError> { - let mut borrows = self.borrows.borrow_mut(); - if !self.contains(r) { - Err(MemoryError::OutOfBounds)?; - } - if borrows.borrow_immut(r) { - Ok(Some(GuestPtrMut { - mem: &self, - region: r, - })) - } else { - Ok(None) - } - } - -} - -pub struct GuestPtr<'a> { - mem: &'a GuestMemory<'a>, - region: Region, -} - -impl<'a> Drop for GuestPtr<'a> { - fn drop(&mut self) { - let mut borrows = self.mem.borrows.borrow_mut(); - borrows.unborrow_immut(self.region); - } -} - - -pub struct GuestPtrMut<'a> { - mem: &'a GuestMemory<'a>, - region: Region, -} - -impl<'a> Drop for GuestPtrMut<'a> { - fn drop(&mut self) { - let mut borrows = self.mem.borrows.borrow_mut(); - borrows.unborrow_mut(self.region); - } -} - -#[derive(Debug, Error)] -pub enum MemoryError { - #[error("Out of bounds")] - OutOfBounds, -} +pub use guest_type::GuestType; +pub use memory::{GuestMemory, GuestPtr, GuestPtrMut}; +pub use region::Region; diff --git a/crates/memory/src/memory.rs b/crates/memory/src/memory.rs new file mode 100644 index 0000000000..7cffb52606 --- /dev/null +++ b/crates/memory/src/memory.rs @@ -0,0 +1,117 @@ +use std::cell::RefCell; +use std::marker::PhantomData; +use std::rc::Rc; +use thiserror::Error; + +use crate::borrow::GuestBorrows; +use crate::guest_type::GuestType; +use crate::region::Region; + +pub struct GuestMemory<'a> { + ptr: *mut u8, + len: u32, + lifetime: PhantomData<&'a ()>, + borrows: Rc>, +} + +impl<'a> GuestMemory<'a> { + pub fn new(ptr: *mut u8, len: u32) -> GuestMemory<'a> { + GuestMemory { + ptr, + len, + lifetime: PhantomData, + borrows: Rc::new(RefCell::new(GuestBorrows::new())), + } + } + + fn contains(&self, r: Region) -> bool { + r.start < self.len + && r.len < self.len // make sure next clause doesnt underflow + && r.start < (self.len - r.len) + } + + pub fn ptr(&'a self, at: u32) -> Result, MemoryError> { + let r = Region { + start: at, + len: T::len(), + }; + let mut borrows = self.borrows.borrow_mut(); + if !self.contains(r) { + Err(MemoryError::OutOfBounds(r))?; + } + if borrows.borrow_immut(r) { + Ok(GuestPtr { + mem: &self, + region: r, + type_: PhantomData, + }) + } else { + Err(MemoryError::Borrowed(r)) + } + } + + pub fn ptr_mut(&'a self, at: u32) -> Result, MemoryError> { + let r = Region { + start: at, + len: T::len(), + }; + let mut borrows = self.borrows.borrow_mut(); + if !self.contains(r) { + Err(MemoryError::OutOfBounds(r))?; + } + if borrows.borrow_mut(r) { + Ok(GuestPtrMut { + mem: &self, + region: r, + type_: PhantomData, + }) + } else { + Err(MemoryError::Borrowed(r)) + } + } +} + +pub struct GuestPtr<'a, T> { + mem: &'a GuestMemory<'a>, + region: Region, + type_: PhantomData, +} + +impl<'a, T> GuestPtr<'a, T> { + pub fn ptr(&self) -> *const u8 { + (self.mem.ptr as usize + self.region.start as usize) as *const u8 + } +} + +impl<'a, T> Drop for GuestPtr<'a, T> { + fn drop(&mut self) { + let mut borrows = self.mem.borrows.borrow_mut(); + borrows.unborrow_immut(self.region); + } +} + +pub struct GuestPtrMut<'a, T> { + mem: &'a GuestMemory<'a>, + region: Region, + type_: PhantomData, +} + +impl<'a, T> GuestPtrMut<'a, T> { + pub fn ptr_mut(&self) -> *mut u8 { + (self.mem.ptr as usize + self.region.start as usize) as *mut u8 + } +} +impl<'a, T> Drop for GuestPtrMut<'a, T> { + fn drop(&mut self) { + let mut borrows = self.mem.borrows.borrow_mut(); + borrows.unborrow_mut(self.region); + } +} + +#[derive(Debug, Error)] +pub enum MemoryError { + #[error("Out of bounds: {0:?}")] + OutOfBounds(Region), + #[error("Borrowed: {0:?}")] + Borrowed(Region), +} diff --git a/crates/memory/src/region.rs b/crates/memory/src/region.rs new file mode 100644 index 0000000000..3df191117f --- /dev/null +++ b/crates/memory/src/region.rs @@ -0,0 +1,37 @@ +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct Region { + pub start: u32, + pub len: u32, +} + +impl Region { + pub fn overlaps(&self, rhs: Region) -> bool { + let self_start = self.start as u64; + let self_end = self.start as u64 + self.len as u64; + + let rhs_start = rhs.start as u64; + let rhs_end = rhs.start as u64 + rhs.len as u64; + + // start of rhs inside self: + if rhs_start >= self_start && rhs_start < self_end { + return true; + } + + // end of rhs inside self: + if rhs_end >= self_start && rhs_end < self_end { + return true; + } + + // start of self inside rhs: + if self_start >= rhs_start && self_start < rhs_end { + return true; + } + + // end of self inside rhs: XXX is this redundant? i suspect it is but im too tired + if self_end >= rhs_start && self_end < rhs_end { + return true; + } + + false + } +} diff --git a/test.witx b/test.witx index f1185b63fe..414b9617cd 100644 --- a/test.witx +++ b/test.witx @@ -8,6 +8,7 @@ (module $foo (@interface func (export "bar") - (param $name string) + (param $an_int (@witx pointer u32)) + (param $an_float (@witx pointer f32)) (result $error $errno)) )