diff --git a/crates/wasi-common/cap-std-sync/src/sched/unix.rs b/crates/wasi-common/cap-std-sync/src/sched/unix.rs index 37fbbc1c0f..030c1bcf68 100644 --- a/crates/wasi-common/cap-std-sync/src/sched/unix.rs +++ b/crates/wasi-common/cap-std-sync/src/sched/unix.rs @@ -32,14 +32,14 @@ impl WasiSched for SyncSched { for s in poll.rw_subscriptions() { match s { Subscription::Read(f) => { - let raw_fd = wasi_file_raw_fd(f.file.deref()).ok_or( + let raw_fd = wasi_file_raw_fd(f.file()?.deref()).ok_or( Error::invalid_argument().context("read subscription fd downcast failed"), )?; pollfds.push(unsafe { PollFd::new(raw_fd, PollFlags::POLLIN) }); } Subscription::Write(f) => { - let raw_fd = wasi_file_raw_fd(f.file.deref()).ok_or( + let raw_fd = wasi_file_raw_fd(f.file()?.deref()).ok_or( Error::invalid_argument().context("write subscription fd downcast failed"), )?; pollfds.push(unsafe { PollFd::new(raw_fd, PollFlags::POLLOUT) }); @@ -79,7 +79,11 @@ impl WasiSched for SyncSched { if let Some(revents) = pollfd.revents() { let (nbytes, rwsub) = match rwsub { Subscription::Read(sub) => { - let ready = sub.file.num_ready_bytes().await?; + let ready = sub + .file() + .expect("validated file already") + .num_ready_bytes() + .await?; (std::cmp::max(ready, 1), sub) } Subscription::Write(sub) => (0, sub), diff --git a/crates/wasi-common/src/sched.rs b/crates/wasi-common/src/sched.rs index b9a8062a4b..984b8d9848 100644 --- a/crates/wasi-common/src/sched.rs +++ b/crates/wasi-common/src/sched.rs @@ -1,8 +1,8 @@ use crate::clocks::WasiMonotonicClock; -use crate::file::WasiFile; -use crate::Error; +use crate::table::Table; +use crate::{Error, ErrorExt}; use cap_std::time::Instant; -use std::cell::RefMut; +use std::collections::HashSet; pub mod subscription; pub use cap_std::time::Duration; @@ -29,12 +29,18 @@ impl From for u64 { } pub struct Poll<'a> { + table: &'a Table, + fds: HashSet, subs: Vec<(Subscription<'a>, Userdata)>, } impl<'a> Poll<'a> { - pub fn new() -> Self { - Self { subs: Vec::new() } + pub fn new(table: &'a Table) -> Self { + Self { + table, + fds: HashSet::new(), + subs: Vec::new(), + } } pub fn subscribe_monotonic_clock( &mut self, @@ -52,13 +58,31 @@ impl<'a> Poll<'a> { ud, )); } - pub fn subscribe_read(&mut self, file: RefMut<'a, dyn WasiFile>, ud: Userdata) { + pub fn subscribe_read(&mut self, fd: u32, ud: Userdata) -> Result<(), Error> { + if self.fds.contains(&fd) { + return Err( + Error::invalid_argument().context("Fd can be subscribed to at most once per poll") + ); + } else { + self.fds.insert(fd); + } self.subs - .push((Subscription::Read(RwSubscription::new(file)), ud)); + .push((Subscription::Read(RwSubscription::new(self.table, fd)?), ud)); + Ok(()) } - pub fn subscribe_write(&mut self, file: RefMut<'a, dyn WasiFile>, ud: Userdata) { - self.subs - .push((Subscription::Write(RwSubscription::new(file)), ud)); + pub fn subscribe_write(&mut self, fd: u32, ud: Userdata) -> Result<(), Error> { + if self.fds.contains(&fd) { + return Err( + Error::invalid_argument().context("Fd can be subscribed to at most once per poll") + ); + } else { + self.fds.insert(fd); + } + self.subs.push(( + Subscription::Write(RwSubscription::new(self.table, fd)?), + ud, + )); + Ok(()) } pub fn results(self) -> Vec<(SubscriptionResult, Userdata)> { self.subs diff --git a/crates/wasi-common/src/sched/subscription.rs b/crates/wasi-common/src/sched/subscription.rs index 3347d3a826..73575ae69c 100644 --- a/crates/wasi-common/src/sched/subscription.rs +++ b/crates/wasi-common/src/sched/subscription.rs @@ -1,5 +1,6 @@ use crate::clocks::WasiMonotonicClock; -use crate::file::WasiFile; +use crate::file::{FileCaps, FileEntryMutExt, TableFileExt, WasiFile}; +use crate::table::Table; use crate::Error; use bitflags::bitflags; use cap_std::time::{Duration, Instant}; @@ -12,16 +13,29 @@ bitflags! { } pub struct RwSubscription<'a> { - pub file: RefMut<'a, dyn WasiFile>, + table: &'a Table, + fd: u32, status: Cell>>, } impl<'a> RwSubscription<'a> { - pub fn new(file: RefMut<'a, dyn WasiFile>) -> Self { - Self { - file, + /// Create an RwSubscription. This constructor checks to make sure the file we need exists, and + /// has the correct rights. But, we can't hold onto the WasiFile RefMut inside this structure + /// (Pat can't convince borrow checker, either not clever enough or a rustc bug), so we need to + /// re-borrow at use time. + pub fn new(table: &'a Table, fd: u32) -> Result { + let _ = table.get_file_mut(fd)?.get_cap(FileCaps::POLL_READWRITE)?; + Ok(Self { + table, + fd, status: Cell::new(None), - } + }) + } + /// This accessor could fail if there is an outstanding borrow of the file. + pub fn file(&self) -> Result, Error> { + self.table + .get_file_mut(self.fd)? + .get_cap(FileCaps::POLL_READWRITE) } pub fn complete(&self, size: u64, flags: RwEventFlags) { self.status.set(Some(Ok((size, flags)))) diff --git a/crates/wasi-common/src/snapshots/preview_0.rs b/crates/wasi-common/src/snapshots/preview_0.rs index 775962afef..6028bf12da 100644 --- a/crates/wasi-common/src/snapshots/preview_0.rs +++ b/crates/wasi-common/src/snapshots/preview_0.rs @@ -1,4 +1,4 @@ -use crate::file::{FileCaps, FileEntryExt, FileEntryMutExt, TableFileExt}; +use crate::file::{FileCaps, FileEntryExt, TableFileExt}; use crate::sched::{ subscription::{RwEventFlags, SubscriptionResult}, Poll, @@ -7,7 +7,6 @@ use crate::snapshots::preview_1::types as snapshot1_types; use crate::snapshots::preview_1::wasi_snapshot_preview1::WasiSnapshotPreview1 as Snapshot1; use crate::{Error, ErrorExt, WasiCtx}; use cap_std::time::Duration; -use std::collections::HashSet; use std::convert::{TryFrom, TryInto}; use std::io::{IoSlice, IoSliceMut}; use std::ops::Deref; @@ -779,8 +778,7 @@ impl wasi_unstable::WasiUnstable for WasiCtx { } let table = self.table(); - let mut subscribed_fds = HashSet::new(); - let mut poll = Poll::new(); + let mut poll = Poll::new(&table); let subs = subs.as_array(nsubscriptions); for sub_elem in subs.iter() { @@ -818,29 +816,11 @@ impl wasi_unstable::WasiUnstable for WasiCtx { }, types::SubscriptionU::FdRead(readsub) => { let fd = readsub.file_descriptor; - if subscribed_fds.contains(&fd) { - Err(Error::invalid_argument() - .context("Fd can be subscribed to at most once per poll_oneoff"))?; - } else { - subscribed_fds.insert(fd); - } - let file = table - .get_file_mut(u32::from(fd))? - .get_cap(FileCaps::POLL_READWRITE)?; - poll.subscribe_read(file, sub.userdata.into()); + poll.subscribe_read(u32::from(fd), sub.userdata.into())?; } types::SubscriptionU::FdWrite(writesub) => { let fd = writesub.file_descriptor; - if subscribed_fds.contains(&fd) { - Err(Error::invalid_argument() - .context("Fd can be subscribed to at most once per poll_oneoff"))?; - } else { - subscribed_fds.insert(fd); - } - let file = table - .get_file_mut(u32::from(fd))? - .get_cap(FileCaps::POLL_READWRITE)?; - poll.subscribe_write(file, sub.userdata.into()); + poll.subscribe_write(u32::from(fd), sub.userdata.into())?; } } } diff --git a/crates/wasi-common/src/snapshots/preview_1.rs b/crates/wasi-common/src/snapshots/preview_1.rs index 04ba4313d0..2487ef339e 100644 --- a/crates/wasi-common/src/snapshots/preview_1.rs +++ b/crates/wasi-common/src/snapshots/preview_1.rs @@ -13,7 +13,6 @@ use crate::{ use anyhow::Context; use cap_std::time::{Duration, SystemClock}; use std::cell::{Ref, RefMut}; -use std::collections::HashSet; use std::convert::{TryFrom, TryInto}; use std::io::{IoSlice, IoSliceMut}; use std::ops::{Deref, DerefMut}; @@ -971,8 +970,7 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { } let table = self.table(); - let mut subscribed_fds = HashSet::new(); - let mut poll = Poll::new(); + let mut poll = Poll::new(&table); let subs = subs.as_array(nsubscriptions); for sub_elem in subs.iter() { @@ -1010,29 +1008,11 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { }, types::SubscriptionU::FdRead(readsub) => { let fd = readsub.file_descriptor; - if subscribed_fds.contains(&fd) { - return Err(Error::invalid_argument() - .context("Fd can be subscribed to at most once per poll_oneoff")); - } else { - subscribed_fds.insert(fd); - } - let file = table - .get_file_mut(u32::from(fd))? - .get_cap(FileCaps::POLL_READWRITE)?; - poll.subscribe_read(file, sub.userdata.into()); + poll.subscribe_read(u32::from(fd), sub.userdata.into())?; } types::SubscriptionU::FdWrite(writesub) => { let fd = writesub.file_descriptor; - if subscribed_fds.contains(&fd) { - return Err(Error::invalid_argument() - .context("Fd can be subscribed to at most once per poll_oneoff")); - } else { - subscribed_fds.insert(fd); - } - let file = table - .get_file_mut(u32::from(fd))? - .get_cap(FileCaps::POLL_READWRITE)?; - poll.subscribe_write(file, sub.userdata.into()); + poll.subscribe_write(u32::from(fd), sub.userdata.into())?; } } } diff --git a/crates/wasi-common/tokio/src/sched/unix.rs b/crates/wasi-common/tokio/src/sched/unix.rs index 7445394840..f286b0d29a 100644 --- a/crates/wasi-common/tokio/src/sched/unix.rs +++ b/crates/wasi-common/tokio/src/sched/unix.rs @@ -1,6 +1,6 @@ use cap_std::time::Duration; use std::convert::TryInto; -use std::future::{Future, Poll as FPoll}; +use std::future::Future; use std::ops::Deref; use std::pin::Pin; use std::task::Context; @@ -23,15 +23,15 @@ pub async fn poll_oneoff<'a>(poll: &'_ Poll<'a>) -> Result<(), Error> { match s { Subscription::Read(f) => { futures.push(Box::pin(async move { - f.file.readable().await?; - f.complete(f.file.num_ready_bytes().await?, RwEventFlags::empty()); + f.file()?.readable().await?; + f.complete(f.file()?.num_ready_bytes().await?, RwEventFlags::empty()); Ok(()) })); } Subscription::Write(f) => { futures.push(Box::pin(async move { - f.file.writable().await?; + f.file()?.writable().await?; f.complete(0, RwEventFlags::empty()); Ok(()) }));