diff --git a/crates/wasi-common/src/file.rs b/crates/wasi-common/src/file.rs index 0c6c9b2a3e..0de79e3128 100644 --- a/crates/wasi-common/src/file.rs +++ b/crates/wasi-common/src/file.rs @@ -25,6 +25,26 @@ pub trait WasiFile: Send + Sync { Err(Error::badf()) } + async fn sock_recv<'a>( + &mut self, + _ri_data: &mut [std::io::IoSliceMut<'a>], + _ri_flags: RiFlags, + ) -> Result<(u64, RoFlags), Error> { + Err(Error::badf()) + } + + async fn sock_send<'a>( + &mut self, + _si_data: &[std::io::IoSlice<'a>], + _si_flags: SiFlags, + ) -> Result { + Err(Error::badf()) + } + + async fn sock_shutdown(&mut self, _how: SdFlags) -> Result<(), Error> { + Err(Error::badf()) + } + async fn datasync(&mut self) -> Result<(), Error> { Ok(()) } @@ -145,6 +165,31 @@ bitflags! { } } +bitflags! { + pub struct SdFlags: u32 { + const RD = 0b1; + const WR = 0b10; + } +} + +bitflags! { + pub struct SiFlags: u32 { + } +} + +bitflags! { + pub struct RiFlags: u32 { + const RECV_PEEK = 0b1; + const RECV_WAITALL = 0b10; + } +} + +bitflags! { + pub struct RoFlags: u32 { + const RECV_DATA_TRUNCATED = 0b1; + } +} + bitflags! { pub struct OFlags: u32 { const CREATE = 0b1; diff --git a/crates/wasi-common/src/snapshots/preview_1.rs b/crates/wasi-common/src/snapshots/preview_1.rs index 9c6f372d3d..2aa34874eb 100644 --- a/crates/wasi-common/src/snapshots/preview_1.rs +++ b/crates/wasi-common/src/snapshots/preview_1.rs @@ -2,7 +2,7 @@ use crate::{ dir::{DirCaps, DirEntry, DirEntryExt, DirFdStat, ReaddirCursor, ReaddirEntity, TableDirExt}, file::{ Advice, FdFlags, FdStat, FileCaps, FileEntry, FileEntryExt, FileType, Filestat, OFlags, - TableFileExt, WasiFile, + RiFlags, RoFlags, SdFlags, SiFlags, TableFileExt, WasiFile, }, sched::{ subscription::{RwEventFlags, SubscriptionResult}, @@ -1164,24 +1164,69 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { async fn sock_recv<'a>( &mut self, - _fd: types::Fd, - _ri_data: &types::IovecArray<'a>, - _ri_flags: types::Riflags, + fd: types::Fd, + ri_data: &types::IovecArray<'a>, + ri_flags: types::Riflags, ) -> Result<(types::Size, types::Roflags), Error> { - Err(Error::trap("sock_recv unsupported")) + let f = self + .table() + .get_file_mut(u32::from(fd))? + .get_cap_mut(FileCaps::READ)?; + + let mut guest_slices: Vec> = ri_data + .iter() + .map(|iov_ptr| { + let iov_ptr = iov_ptr?; + let iov: types::Iovec = iov_ptr.read()?; + Ok(iov.buf.as_array(iov.buf_len).as_slice_mut()?) + }) + .collect::>()?; + + let mut ioslices: Vec = guest_slices + .iter_mut() + .map(|s| IoSliceMut::new(&mut *s)) + .collect(); + + let (bytes_read, roflags) = f.sock_recv(&mut ioslices, RiFlags::from(ri_flags)).await?; + Ok((types::Size::try_from(bytes_read)?, roflags.into())) } async fn sock_send<'a>( &mut self, - _fd: types::Fd, - _si_data: &types::CiovecArray<'a>, + fd: types::Fd, + si_data: &types::CiovecArray<'a>, _si_flags: types::Siflags, ) -> Result { - Err(Error::trap("sock_send unsupported")) + let f = self + .table() + .get_file_mut(u32::from(fd))? + .get_cap_mut(FileCaps::WRITE)?; + + let guest_slices: Vec> = si_data + .iter() + .map(|iov_ptr| { + let iov_ptr = iov_ptr?; + let iov: types::Ciovec = iov_ptr.read()?; + Ok(iov.buf.as_array(iov.buf_len).as_slice()?) + }) + .collect::>()?; + + let ioslices: Vec = guest_slices + .iter() + .map(|s| IoSlice::new(s.deref())) + .collect(); + let bytes_written = f.sock_send(&ioslices, SiFlags::empty()).await?; + + Ok(types::Size::try_from(bytes_written)?) } - async fn sock_shutdown(&mut self, _fd: types::Fd, _how: types::Sdflags) -> Result<(), Error> { - Err(Error::trap("sock_shutdown unsupported")) + async fn sock_shutdown(&mut self, fd: types::Fd, how: types::Sdflags) -> Result<(), Error> { + let f = self + .table() + .get_file_mut(u32::from(fd))? + .get_cap_mut(FileCaps::FDSTAT_SET_FLAGS)?; + + f.sock_shutdown(SdFlags::from(how)).await } } @@ -1477,6 +1522,12 @@ convert_flags_bidirectional!( SYNC ); +convert_flags_bidirectional!(RiFlags, types::Riflags, RECV_PEEK, RECV_WAITALL); + +convert_flags_bidirectional!(RoFlags, types::Roflags, RECV_DATA_TRUNCATED); + +convert_flags_bidirectional!(SdFlags, types::Sdflags, RD, WR); + impl From<&types::Oflags> for OFlags { fn from(oflags: &types::Oflags) -> OFlags { let mut out = OFlags::empty();