diff --git a/crates/generate/src/funcs.rs b/crates/generate/src/funcs.rs index b2728db3ca..57a9e63da7 100644 --- a/crates/generate/src/funcs.rs +++ b/crates/generate/src/funcs.rs @@ -181,7 +181,36 @@ fn marshal_arg( let #name = #name as #interface_typename; } } - witx::BuiltinType::String => unimplemented!("string types unimplemented"), + witx::BuiltinType::String => { + let lifetime = anon_lifetime(); + let ptr_name = names.func_ptr_binding(&param.name); + let len_name = names.func_len_binding(&param.name); + let name = names.func_param(&param.name); + quote! { + let num_elems = match memory.ptr::<u32>(#len_name as u32) { + Ok(p) => match p.as_ref() { + Ok(r) => r, + Err(e) => { + #error_handling + } + } + Err(e) => { + #error_handling + } + }; + let #name: wiggle_runtime::GuestString<#lifetime> = match memory.ptr::<u8>(#ptr_name as u32) { + Ok(p) => match p.array(*num_elems) { + Ok(s) => s.into(), + Err(e) => { + #error_handling + } + } + Err(e) => { + #error_handling + } + }; + } + } }, witx::Type::Pointer(pointee) => { let pointee_type = names.type_ref(pointee, anon_lifetime()); diff --git a/crates/generate/src/names.rs b/crates/generate/src/names.rs index e673af3197..3337a18619 100644 --- a/crates/generate/src/names.rs +++ b/crates/generate/src/names.rs @@ -22,9 +22,9 @@ impl Names { let ident = format_ident!("{}", id.as_str().to_camel_case()); quote!(#ident) } - pub fn builtin_type(&self, b: BuiltinType) -> TokenStream { + pub fn builtin_type(&self, b: BuiltinType, lifetime: TokenStream) -> TokenStream { match b { - BuiltinType::String => quote!(String), + BuiltinType::String => quote!(wiggle_runtime::GuestString<#lifetime>), BuiltinType::U8 => quote!(u8), BuiltinType::U16 => quote!(u16), BuiltinType::U32 => quote!(u32), @@ -35,7 +35,7 @@ impl Names { BuiltinType::S64 => quote!(i64), BuiltinType::F32 => quote!(f32), BuiltinType::F64 => quote!(f64), - BuiltinType::Char8 => quote!(char), + BuiltinType::Char8 => quote!(u8), BuiltinType::USize => quote!(usize), } } @@ -59,7 +59,7 @@ 
impl Names { } } TypeRef::Value(ty) => match &**ty { - witx::Type::Builtin(builtin) => self.builtin_type(*builtin), + witx::Type::Builtin(builtin) => self.builtin_type(*builtin, lifetime.clone()), witx::Type::Pointer(pointee) => { let pointee_type = self.type_ref(&pointee, lifetime.clone()); quote!(wiggle_runtime::GuestPtrMut<#lifetime, #pointee_type>) diff --git a/crates/generate/src/types.rs b/crates/generate/src/types.rs index d2fc80fbce..61e21b71a5 100644 --- a/crates/generate/src/types.rs +++ b/crates/generate/src/types.rs @@ -285,14 +285,18 @@ fn define_enum(names: &Names, name: &witx::Id, e: &witx::EnumDatatype) -> TokenS fn define_builtin(names: &Names, name: &witx::Id, builtin: witx::BuiltinType) -> TokenStream { let ident = names.type_(name); - let built = names.builtin_type(builtin); - quote!(pub type #ident = #built;) + let built = names.builtin_type(builtin, quote!('a)); + if let witx::BuiltinType::String = builtin { + quote!(pub type #ident<'a> = #built;) + } else { + quote!(pub type #ident = #built;) + } } pub fn type_needs_lifetime(tref: &witx::TypeRef) -> bool { match &*tref.type_() { witx::Type::Builtin(b) => match b { - witx::BuiltinType::String => unimplemented!(), + witx::BuiltinType::String => true, _ => false, }, witx::Type::Enum { .. 
} @@ -392,7 +396,7 @@ fn define_ptr_struct(names: &Names, name: &witx::Id, s: &witx::StructDatatype) - let type_ = match &m.tref { witx::TypeRef::Name(nt) => names.type_(&nt.name), witx::TypeRef::Value(ty) => match &**ty { - witx::Type::Builtin(builtin) => names.builtin_type(*builtin), + witx::Type::Builtin(builtin) => names.builtin_type(*builtin, quote!('a)), witx::Type::Pointer(pointee) => { let pointee_type = names.type_ref(&pointee, quote!('a)); quote!(wiggle_runtime::GuestPtrMut<'a, #pointee_type>) @@ -410,7 +414,7 @@ fn define_ptr_struct(names: &Names, name: &witx::Id, s: &witx::StructDatatype) - let type_ = match &ml.member.tref { witx::TypeRef::Name(nt) => names.type_(&nt.name), witx::TypeRef::Value(ty) => match &**ty { - witx::Type::Builtin(builtin) => names.builtin_type(*builtin), + witx::Type::Builtin(builtin) => names.builtin_type(*builtin, quote!('a)), witx::Type::Pointer(pointee) => { let pointee_type = names.type_ref(&pointee, anon_lifetime()); quote!(wiggle_runtime::GuestPtrMut::<#pointee_type>) @@ -453,7 +457,7 @@ fn define_ptr_struct(names: &Names, name: &witx::Id, s: &witx::StructDatatype) - } witx::TypeRef::Value(ty) => match &**ty { witx::Type::Builtin(builtin) => { - let type_ = names.builtin_type(*builtin); + let type_ = names.builtin_type(*builtin, anon_lifetime()); quote! 
{ let #name = #type_::read_from_guest(&location.cast(#offset)?)?; } diff --git a/crates/runtime/src/error.rs b/crates/runtime/src/error.rs index 6768f67060..9c9e8a8e9a 100644 --- a/crates/runtime/src/error.rs +++ b/crates/runtime/src/error.rs @@ -27,4 +27,6 @@ pub enum GuestError { #[source] err: Box, }, + #[error("Invalid UTF-8 encountered")] + InvalidUtf8(#[from] std::str::Utf8Error), } diff --git a/crates/runtime/src/lib.rs b/crates/runtime/src/lib.rs index b07da6a21d..839c646b0b 100644 --- a/crates/runtime/src/lib.rs +++ b/crates/runtime/src/lib.rs @@ -6,5 +6,8 @@ mod region; pub use error::GuestError; pub use guest_type::{GuestErrorType, GuestType, GuestTypeClone, GuestTypeCopy}; -pub use memory::{GuestArray, GuestMemory, GuestPtr, GuestPtrMut, GuestRef, GuestRefMut}; +pub use memory::{ + GuestArray, GuestMemory, GuestPtr, GuestPtrMut, GuestRef, GuestRefMut, GuestString, + GuestStringRef, +}; pub use region::Region; diff --git a/crates/runtime/src/memory/mod.rs b/crates/runtime/src/memory/mod.rs index 5892df29da..2800da56f1 100644 --- a/crates/runtime/src/memory/mod.rs +++ b/crates/runtime/src/memory/mod.rs @@ -1,8 +1,10 @@ mod array; mod ptr; +mod string; pub use array::*; pub use ptr::*; +pub use string::*; use crate::{borrow::GuestBorrows, GuestError, GuestType, Region}; use std::{cell::RefCell, fmt, marker::PhantomData, rc::Rc}; diff --git a/crates/runtime/src/memory/string.rs b/crates/runtime/src/memory/string.rs new file mode 100644 index 0000000000..a4e6deac6d --- /dev/null +++ b/crates/runtime/src/memory/string.rs @@ -0,0 +1,124 @@ +use super::array::{GuestArray, GuestArrayRef}; +use crate::GuestError; +use std::fmt; + +pub struct GuestString<'a> { + pub(super) array: GuestArray<'a, u8>, +} + +impl<'a> fmt::Debug for GuestString<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "GuestString {{ array: {:?} }}", self.array) + } +} + +impl<'a> GuestString<'a> { + pub fn as_ref(&self) -> Result<GuestStringRef<'a>, GuestError> { + let ref_ = 
 self.array.as_ref()?; + Ok(GuestStringRef { ref_ }) + } + + pub fn to_string(&self) -> Result<String, GuestError> { + Ok(self.as_ref()?.as_str()?.to_owned()) + } +} + +impl<'a> From<GuestArray<'a, u8>> for GuestString<'a> { + fn from(array: GuestArray<'a, u8>) -> Self { + Self { array } + } +} + +pub struct GuestStringRef<'a> { + pub(super) ref_: GuestArrayRef<'a, u8>, +} + +impl<'a> fmt::Debug for GuestStringRef<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "GuestStringRef {{ ref_: {:?} }}", self.ref_) + } +} + +impl<'a> GuestStringRef<'a> { + pub fn as_str(&self) -> Result<&str, GuestError> { + std::str::from_utf8(&*self.ref_).map_err(Into::into) + } +} + +#[cfg(test)] +mod test { + use super::{ + super::{ + ptr::{GuestPtr, GuestPtrMut}, + GuestError, GuestMemory, + }, + GuestString, + }; + + #[repr(align(4096))] + struct HostMemory { + buffer: [u8; 4096], + } + + impl HostMemory { + pub fn new() -> Self { + Self { buffer: [0; 4096] } + } + pub fn as_mut_ptr(&mut self) -> *mut u8 { + self.buffer.as_mut_ptr() + } + pub fn len(&self) -> usize { + self.buffer.len() + } + } + + #[test] + fn valid_utf8() { + let mut host_memory = HostMemory::new(); + let guest_memory = GuestMemory::new(host_memory.as_mut_ptr(), host_memory.len() as u32); + // write string into memory + let mut ptr: GuestPtrMut<u8> = guest_memory.ptr_mut(0).expect("ptr mut to start of string"); + let input_str = "cześć WASI!"; + for byte in input_str.as_bytes() { + let mut ref_mut = ptr.as_ref_mut().expect("valid deref"); + *ref_mut = *byte; + ptr = ptr.elem(1).expect("next ptr"); + } + // read the string as GuestString + let ptr: GuestPtr<u8> = guest_memory.ptr(0).expect("ptr to start of string"); + let guest_string: GuestString<'_> = ptr + .array(input_str.len() as u32) + .expect("valid null-terminated string") + .into(); + let as_ref = guest_string.as_ref().expect("deref"); + assert_eq!(as_ref.as_str().expect("valid UTF-8"), input_str); + } + + #[test] + fn invalid_utf8() { + let mut host_memory = 
HostMemory::new(); + let guest_memory = GuestMemory::new(host_memory.as_mut_ptr(), host_memory.len() as u32); + // write string into memory + let mut ptr: GuestPtrMut<u8> = guest_memory.ptr_mut(0).expect("ptr mut to start of string"); + let input_str = "cześć WASI!"; + let mut bytes = input_str.as_bytes().to_vec(); + // insert 0xFE which is an invalid UTF-8 byte + bytes[5] = 0xfe; + for byte in &bytes { + let mut ref_mut = ptr.as_ref_mut().expect("valid deref"); + *ref_mut = *byte; + ptr = ptr.elem(1).expect("next ptr"); + } + // read the string as GuestString + let ptr: GuestPtr<u8> = guest_memory.ptr(0).expect("ptr to start of string"); + let guest_string: GuestString<'_> = ptr + .array(bytes.len() as u32) + .expect("valid null-terminated string") + .into(); + let as_ref = guest_string.as_ref().expect("deref"); + match as_ref.as_str().expect_err("should fail") { + GuestError::InvalidUtf8(_) => {} + x => assert!(false, "expected GuestError::InvalidUtf8(_), got {:?}", x), + } + } +} diff --git a/tests/main.rs b/tests/main.rs index 41aae0f5ab..a64395ffa3 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -2,7 +2,7 @@ use proptest::prelude::*; use std::convert::TryFrom; use wiggle_runtime::{ GuestArray, GuestError, GuestErrorType, GuestMemory, GuestPtr, GuestPtrMut, GuestRef, - GuestRefMut, + GuestRefMut, GuestString, }; wiggle_generate::from_witx!({ @@ -136,6 +136,13 @@ impl foo::Foo for WasiCtx { })?; Ok(old_config ^ other_config) } + + fn hello_string(&mut self, a_string: &GuestString<'_>) -> Result<u32, types::Errno> { + let as_ref = a_string.as_ref().expect("deref ptr should succeed"); + let as_str = as_ref.as_str().expect("valid UTF-8 string"); + println!("a_string='{}'", as_str); + Ok(as_str.len() as u32) + } } // Errno is used as a first return value in the functions above, therefore // it must implement GuestErrorType with type Context = WasiCtx. @@ -867,3 +874,87 @@ proptest! 
{ e.test() } } + +fn test_string_strategy() -> impl Strategy<Value = String> { + "\\p{Greek}{1,256}" +} + +#[derive(Debug)] +struct HelloStringExercise { + test_word: String, + string_ptr_loc: MemArea, + string_len_loc: MemArea, + return_ptr_loc: MemArea, +} + +impl HelloStringExercise { + pub fn strat() -> BoxedStrategy<Self> { + (test_string_strategy(),) + .prop_flat_map(|(test_word,)| { + ( + Just(test_word.clone()), + HostMemory::mem_area_strat(test_word.len() as u32), + HostMemory::mem_area_strat(4), + HostMemory::mem_area_strat(4), + ) + }) + .prop_map( + |(test_word, string_ptr_loc, string_len_loc, return_ptr_loc)| Self { + test_word, + string_ptr_loc, + string_len_loc, + return_ptr_loc, + }, + ) + .prop_filter("non-overlapping pointers", |e| { + non_overlapping_set(&[&e.string_ptr_loc, &e.string_len_loc, &e.return_ptr_loc]) + }) + .boxed() + } + + pub fn test(&self) { + let mut ctx = WasiCtx::new(); + let mut host_memory = HostMemory::new(); + let mut guest_memory = GuestMemory::new(host_memory.as_mut_ptr(), host_memory.len() as u32); + + // Populate string length + *guest_memory + .ptr_mut(self.string_len_loc.ptr) + .expect("ptr mut to string len") + .as_ref_mut() + .expect("deref ptr mut to string len") = self.test_word.len() as u32; + + // Populate string in guest's memory + { + let mut next: GuestPtrMut<'_, u8> = guest_memory + .ptr_mut(self.string_ptr_loc.ptr) + .expect("ptr mut to the first byte of string"); + for byte in self.test_word.as_bytes() { + *next.as_ref_mut().expect("deref mut") = *byte; + next = next.elem(1).expect("increment ptr by 1"); + } + } + + let res = foo::hello_string( + &mut ctx, + &mut guest_memory, + self.string_ptr_loc.ptr as i32, + self.string_len_loc.ptr as i32, + self.return_ptr_loc.ptr as i32, + ); + assert_eq!(res, types::Errno::Ok.into(), "hello string errno"); + + let given = *guest_memory + .ptr::<u32>(self.return_ptr_loc.ptr) + .expect("ptr to return value") + .as_ref() + .expect("deref ptr to return value"); + assert_eq!(self.test_word.len() 
as u32, given); + } +} +proptest! { + #[test] + fn hello_string(e in HelloStringExercise::strat()) { + e.test() + } +} diff --git a/tests/test.witx b/tests/test.witx index e7dda02d5f..a6edac306f 100644 --- a/tests/test.witx +++ b/tests/test.witx @@ -72,4 +72,9 @@ (result $error $errno) (result $new_config $car_config) ) + (@interface func (export "hello_string") + (param $a_string string) + (result $error $errno) + (result $total_bytes u32) + ) )