diff --git a/crates/test-programs/wasi-tests/src/bin/poll_oneoff.rs b/crates/test-programs/wasi-tests/src/bin/poll_oneoff_files.rs similarity index 100% rename from crates/test-programs/wasi-tests/src/bin/poll_oneoff.rs rename to crates/test-programs/wasi-tests/src/bin/poll_oneoff_files.rs diff --git a/crates/wasi-common/cap-std-sync/src/sched/unix.rs b/crates/wasi-common/cap-std-sync/src/sched/unix.rs index 030c1bcf68..abf0338b1e 100644 --- a/crates/wasi-common/cap-std-sync/src/sched/unix.rs +++ b/crates/wasi-common/cap-std-sync/src/sched/unix.rs @@ -1,6 +1,5 @@ use cap_std::time::Duration; use std::convert::TryInto; -use std::ops::Deref; use std::os::unix::io::{AsRawFd, RawFd}; use wasi_common::{ file::WasiFile, @@ -23,23 +22,22 @@ impl SyncSched { #[wiggle::async_trait] impl WasiSched for SyncSched { - async fn poll_oneoff<'a>(&self, poll: &'_ Poll<'a>) -> Result<(), Error> { + async fn poll_oneoff<'a>(&self, poll: &'a Poll<'a>) -> Result<(), Error> { if poll.is_empty() { return Ok(()); } let mut pollfds = Vec::new(); - let timeout = poll.earliest_clock_deadline(); for s in poll.rw_subscriptions() { match s { Subscription::Read(f) => { - let raw_fd = wasi_file_raw_fd(f.file()?.deref()).ok_or( + let raw_fd = wasi_file_raw_fd(f.file).ok_or( Error::invalid_argument().context("read subscription fd downcast failed"), )?; pollfds.push(unsafe { PollFd::new(raw_fd, PollFlags::POLLIN) }); } Subscription::Write(f) => { - let raw_fd = wasi_file_raw_fd(f.file()?.deref()).ok_or( + let raw_fd = wasi_file_raw_fd(f.file).ok_or( Error::invalid_argument().context("write subscription fd downcast failed"), )?; pollfds.push(unsafe { PollFd::new(raw_fd, PollFlags::POLLOUT) }); @@ -49,7 +47,7 @@ impl WasiSched for SyncSched { } let ready = loop { - let poll_timeout = if let Some(t) = timeout { + let poll_timeout = if let Some(t) = poll.earliest_clock_deadline() { let duration = t.duration_until().unwrap_or(Duration::from_secs(0)); (duration.as_millis() + 1) // XXX try always rounding up? .try_into() @@ -79,11 +77,7 @@ impl WasiSched for SyncSched { if let Some(revents) = pollfd.revents() { let (nbytes, rwsub) = match rwsub { Subscription::Read(sub) => { - let ready = sub - .file() - .expect("validated file already") - .num_ready_bytes() - .await?; + let ready = sub.file.num_ready_bytes().await?; (std::cmp::max(ready, 1), sub) } Subscription::Write(sub) => (0, sub), @@ -101,7 +95,7 @@ impl WasiSched for SyncSched { } } } else { - timeout + poll.earliest_clock_deadline() .expect("timed out") .result() .expect("timer deadline is past") diff --git a/crates/wasi-common/src/error.rs b/crates/wasi-common/src/error.rs index cf132b59ee..20277554fc 100644 --- a/crates/wasi-common/src/error.rs +++ b/crates/wasi-common/src/error.rs @@ -23,7 +23,7 @@ //! The real value of using `anyhow::Error` here is being able to use //! `anyhow::Result::context` to aid in debugging of errors. -pub use anyhow::Error; +pub use anyhow::{Context, Error}; /// Internal error type for the `wasi-common` crate. /// Contains variants of the WASI `$errno` type are added according to what is actually used internally by diff --git a/crates/wasi-common/src/lib.rs b/crates/wasi-common/src/lib.rs index 4575b423b6..63910d4a60 100644 --- a/crates/wasi-common/src/lib.rs +++ b/crates/wasi-common/src/lib.rs @@ -66,7 +66,7 @@ pub use cap_rand::RngCore; pub use clocks::{SystemTimeSpec, WasiClocks, WasiMonotonicClock, WasiSystemClock}; pub use ctx::{WasiCtx, WasiCtxBuilder}; pub use dir::WasiDir; -pub use error::{Error, ErrorExt, ErrorKind}; +pub use error::{Context, Error, ErrorExt, ErrorKind}; pub use file::WasiFile; pub use sched::{Poll, WasiSched}; pub use string_array::StringArrayError; diff --git a/crates/wasi-common/src/sched.rs b/crates/wasi-common/src/sched.rs index 984b8d9848..e437f48727 100644 --- a/crates/wasi-common/src/sched.rs +++ b/crates/wasi-common/src/sched.rs @@ -1,8 +1,7 @@ use crate::clocks::WasiMonotonicClock; -use crate::table::Table; -use crate::{Error, ErrorExt}; +use crate::file::WasiFile; +use crate::Error; use cap_std::time::Instant; -use std::collections::HashSet; pub mod subscription; pub use cap_std::time::Duration; @@ -10,11 +9,12 @@ use subscription::{MonotonicClockSubscription, RwSubscription, Subscription, Sub #[wiggle::async_trait] pub trait WasiSched { - async fn poll_oneoff<'a>(&self, poll: &Poll<'a>) -> Result<(), Error>; + async fn poll_oneoff<'a>(&self, poll: &'a Poll<'a>) -> Result<(), Error>; async fn sched_yield(&self) -> Result<(), Error>; async fn sleep(&self, duration: Duration) -> Result<(), Error>; } +#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub struct Userdata(u64); impl From for Userdata { fn from(u: u64) -> Userdata { @@ -28,19 +28,15 @@ impl From for u64 { } } +pub type PollResults = Vec<(SubscriptionResult, Userdata)>; + pub struct Poll<'a> { - table: &'a Table, - fds: HashSet, subs: Vec<(Subscription<'a>, Userdata)>, } impl<'a> Poll<'a> { - pub fn new(table: &'a Table) -> Self { - Self { - table, - fds: HashSet::new(), - subs: Vec::new(), - } + pub fn new() -> Self { + Self { subs: Vec::new() } } pub fn subscribe_monotonic_clock( &mut self, @@ -58,36 +54,18 @@ impl<'a> Poll<'a> { ud, )); } - pub fn subscribe_read(&mut self, fd: u32, ud: Userdata) -> Result<(), Error> { - if self.fds.contains(&fd) { - return Err( - Error::invalid_argument().context("Fd can be subscribed to at most once per poll") - ); - } else { - self.fds.insert(fd); - } + pub fn subscribe_read(&mut self, file: &'a mut dyn WasiFile, ud: Userdata) { self.subs - .push((Subscription::Read(RwSubscription::new(self.table, fd)?), ud)); - Ok(()) + .push((Subscription::Read(RwSubscription::new(file)), ud)); } - pub fn subscribe_write(&mut self, fd: u32, ud: Userdata) -> Result<(), Error> { - if self.fds.contains(&fd) { - return Err( - Error::invalid_argument().context("Fd can be subscribed to at most once per poll") - ); - } else { - self.fds.insert(fd); - } - self.subs.push(( - Subscription::Write(RwSubscription::new(self.table, fd)?), - ud, - )); - Ok(()) - } - pub fn results(self) -> Vec<(SubscriptionResult, Userdata)> { + pub fn subscribe_write(&mut self, file: &'a mut dyn WasiFile, ud: Userdata) { self.subs - .into_iter() - .filter_map(|(s, ud)| SubscriptionResult::from_subscription(s).map(|r| (r, ud))) + .push((Subscription::Write(RwSubscription::new(file)), ud)); + } + pub fn results(&self) -> Vec<(SubscriptionResult, Userdata)> { + self.subs + .iter() + .filter_map(|(s, ud)| SubscriptionResult::from_subscription(s).map(|r| (r, *ud))) .collect() } pub fn is_empty(&self) -> bool { diff --git a/crates/wasi-common/src/sched/subscription.rs b/crates/wasi-common/src/sched/subscription.rs index 73575ae69c..d333b74391 100644 --- a/crates/wasi-common/src/sched/subscription.rs +++ b/crates/wasi-common/src/sched/subscription.rs @@ -1,10 +1,9 @@ use crate::clocks::WasiMonotonicClock; -use crate::file::{FileCaps, FileEntryMutExt, TableFileExt, WasiFile}; -use crate::table::Table; +use crate::file::WasiFile; use crate::Error; use bitflags::bitflags; use cap_std::time::{Duration, Instant}; -use std::cell::{Cell, RefMut}; +use std::cell::Cell; bitflags! { pub struct RwEventFlags: u32 { @@ -13,29 +12,16 @@ bitflags! { } pub struct RwSubscription<'a> { - table: &'a Table, - fd: u32, + pub file: &'a mut dyn WasiFile, status: Cell>>, } impl<'a> RwSubscription<'a> { - /// Create an RwSubscription. This constructor checks to make sure the file we need exists, and - /// has the correct rights. But, we can't hold onto the WasiFile RefMut inside this structure - /// (Pat can't convince borrow checker, either not clever enough or a rustc bug), so we need to - /// re-borrow at use time. - pub fn new(table: &'a Table, fd: u32) -> Result { - let _ = table.get_file_mut(fd)?.get_cap(FileCaps::POLL_READWRITE)?; - Ok(Self { - table, - fd, + pub fn new(file: &'a mut dyn WasiFile) -> Self { + Self { + file, status: Cell::new(None), - }) - } - /// This accessor could fail if there is an outstanding borrow of the file. - pub fn file(&self) -> Result, Error> { - self.table - .get_file_mut(self.fd)? - .get_cap(FileCaps::POLL_READWRITE) + } } pub fn complete(&self, size: u64, flags: RwEventFlags) { self.status.set(Some(Ok((size, flags)))) @@ -43,8 +29,8 @@ impl<'a> RwSubscription<'a> { pub fn error(&self, error: Error) { self.status.set(Some(Err(error))) } - pub fn result(self) -> Option> { - self.status.into_inner() + pub fn result(&self) -> Option> { + self.status.take() } } @@ -83,7 +69,7 @@ pub enum SubscriptionResult { } impl SubscriptionResult { - pub fn from_subscription(s: Subscription) -> Option { + pub fn from_subscription(s: &Subscription) -> Option { match s { Subscription::Read(s) => s.result().map(SubscriptionResult::Read), Subscription::Write(s) => s.result().map(SubscriptionResult::Write), diff --git a/crates/wasi-common/src/snapshots/preview_0.rs b/crates/wasi-common/src/snapshots/preview_0.rs index 6028bf12da..f933dfb04d 100644 --- a/crates/wasi-common/src/snapshots/preview_0.rs +++ b/crates/wasi-common/src/snapshots/preview_0.rs @@ -1,12 +1,14 @@ -use crate::file::{FileCaps, FileEntryExt, TableFileExt}; +use crate::file::{FileCaps, FileEntryExt, FileEntryMutExt, TableFileExt, WasiFile}; use crate::sched::{ subscription::{RwEventFlags, SubscriptionResult}, - Poll, + Poll, Userdata, }; use crate::snapshots::preview_1::types as snapshot1_types; use crate::snapshots::preview_1::wasi_snapshot_preview1::WasiSnapshotPreview1 as Snapshot1; use crate::{Error, ErrorExt, WasiCtx}; use cap_std::time::Duration; +use std::cell::RefMut; +use std::collections::HashSet; use std::convert::{TryFrom, TryInto}; use std::io::{IoSlice, IoSliceMut}; use std::ops::Deref; @@ -778,7 +780,11 @@ impl wasi_unstable::WasiUnstable for WasiCtx { } let table = self.table(); - let mut poll = Poll::new(&table); + let mut sub_fds: HashSet = HashSet::new(); + // We need these refmuts to outlive Poll, which will hold the &mut dyn WasiFile inside + let mut read_refs: Vec<(RefMut<'_, dyn WasiFile>, Userdata)> = Vec::new(); + let mut write_refs: Vec<(RefMut<'_, dyn WasiFile>, Userdata)> = Vec::new(); + let mut poll = Poll::new(); let subs = subs.as_array(nsubscriptions); for sub_elem in subs.iter() { @@ -816,11 +822,29 @@ impl wasi_unstable::WasiUnstable for WasiCtx { }, types::SubscriptionU::FdRead(readsub) => { let fd = readsub.file_descriptor; - poll.subscribe_read(u32::from(fd), sub.userdata.into())?; + if sub_fds.contains(&fd) { + return Err(Error::invalid_argument() + .context("Fd can be subscribed to at most once per poll")); + } else { + sub_fds.insert(fd); + } + let file_ref = table + .get_file_mut(u32::from(fd))? + .get_cap(FileCaps::POLL_READWRITE)?; + read_refs.push((file_ref, sub.userdata.into())); } types::SubscriptionU::FdWrite(writesub) => { let fd = writesub.file_descriptor; - poll.subscribe_write(u32::from(fd), sub.userdata.into())?; + if sub_fds.contains(&fd) { + return Err(Error::invalid_argument() + .context("Fd can be subscribed to at most once per poll")); + } else { + sub_fds.insert(fd); + } + let file_ref = table + .get_file_mut(u32::from(fd))? + .get_cap(FileCaps::POLL_READWRITE)?; + write_refs.push((file_ref, sub.userdata.into())); } } } diff --git a/crates/wasi-common/src/snapshots/preview_1.rs b/crates/wasi-common/src/snapshots/preview_1.rs index 2487ef339e..2350e8663b 100644 --- a/crates/wasi-common/src/snapshots/preview_1.rs +++ b/crates/wasi-common/src/snapshots/preview_1.rs @@ -2,17 +2,18 @@ use crate::{ dir::{DirCaps, DirEntry, DirEntryExt, DirFdStat, ReaddirCursor, ReaddirEntity, TableDirExt}, file::{ Advice, FdFlags, FdStat, FileCaps, FileEntry, FileEntryExt, FileEntryMutExt, FileType, - Filestat, OFlags, TableFileExt, + Filestat, OFlags, TableFileExt, WasiFile, }, sched::{ subscription::{RwEventFlags, SubscriptionResult}, - Poll, + Poll, Userdata, }, Error, ErrorExt, ErrorKind, SystemTimeSpec, WasiCtx, }; use anyhow::Context; use cap_std::time::{Duration, SystemClock}; use std::cell::{Ref, RefMut}; +use std::collections::HashSet; use std::convert::{TryFrom, TryInto}; use std::io::{IoSlice, IoSliceMut}; use std::ops::{Deref, DerefMut}; @@ -970,7 +971,11 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { } let table = self.table(); - let mut poll = Poll::new(&table); + let mut sub_fds: HashSet = HashSet::new(); + // We need these refmuts to outlive Poll, which will hold the &mut dyn WasiFile inside + let mut read_refs: Vec<(RefMut<'_, dyn WasiFile>, Userdata)> = Vec::new(); + let mut write_refs: Vec<(RefMut<'_, dyn WasiFile>, Userdata)> = Vec::new(); + let mut poll = Poll::new(); let subs = subs.as_array(nsubscriptions); for sub_elem in subs.iter() { @@ -1008,15 +1013,40 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { }, types::SubscriptionU::FdRead(readsub) => { let fd = readsub.file_descriptor; - poll.subscribe_read(u32::from(fd), sub.userdata.into())?; + if sub_fds.contains(&fd) { + return Err(Error::invalid_argument() + .context("Fd can be subscribed to at most once per poll")); + } else { + sub_fds.insert(fd); + } + let file_ref = table + .get_file_mut(u32::from(fd))? + .get_cap(FileCaps::POLL_READWRITE)?; + read_refs.push((file_ref, sub.userdata.into())); } types::SubscriptionU::FdWrite(writesub) => { let fd = writesub.file_descriptor; - poll.subscribe_write(u32::from(fd), sub.userdata.into())?; + if sub_fds.contains(&fd) { + return Err(Error::invalid_argument() + .context("Fd can be subscribed to at most once per poll")); + } else { + sub_fds.insert(fd); + } + let file_ref = table + .get_file_mut(u32::from(fd))? + .get_cap(FileCaps::POLL_READWRITE)?; + write_refs.push((file_ref, sub.userdata.into())); } } } + for (f, ud) in read_refs.iter_mut() { + poll.subscribe_read(f.deref_mut(), *ud); + } + for (f, ud) in write_refs.iter_mut() { + poll.subscribe_write(f.deref_mut(), *ud); + } + self.sched.poll_oneoff(&poll).await?; let results = poll.results();