jump through enough hoops for the poll lifetime to work out

you program rust for a few years and you think you're done tearing your
hair out over lifetimes, well, you'll find yourself wrong
This commit is contained in:
Pat Hickey
2021-04-29 16:44:45 -07:00
parent ab4f5bb674
commit b7efcbe80f
8 changed files with 99 additions and 87 deletions

View File

@@ -1,6 +1,5 @@
use cap_std::time::Duration; use cap_std::time::Duration;
use std::convert::TryInto; use std::convert::TryInto;
use std::ops::Deref;
use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::io::{AsRawFd, RawFd};
use wasi_common::{ use wasi_common::{
file::WasiFile, file::WasiFile,
@@ -23,23 +22,22 @@ impl SyncSched {
#[wiggle::async_trait] #[wiggle::async_trait]
impl WasiSched for SyncSched { impl WasiSched for SyncSched {
async fn poll_oneoff<'a>(&self, poll: &'_ Poll<'a>) -> Result<(), Error> { async fn poll_oneoff<'a>(&self, poll: &'a Poll<'a>) -> Result<(), Error> {
if poll.is_empty() { if poll.is_empty() {
return Ok(()); return Ok(());
} }
let mut pollfds = Vec::new(); let mut pollfds = Vec::new();
let timeout = poll.earliest_clock_deadline();
for s in poll.rw_subscriptions() { for s in poll.rw_subscriptions() {
match s { match s {
Subscription::Read(f) => { Subscription::Read(f) => {
let raw_fd = wasi_file_raw_fd(f.file()?.deref()).ok_or( let raw_fd = wasi_file_raw_fd(f.file).ok_or(
Error::invalid_argument().context("read subscription fd downcast failed"), Error::invalid_argument().context("read subscription fd downcast failed"),
)?; )?;
pollfds.push(unsafe { PollFd::new(raw_fd, PollFlags::POLLIN) }); pollfds.push(unsafe { PollFd::new(raw_fd, PollFlags::POLLIN) });
} }
Subscription::Write(f) => { Subscription::Write(f) => {
let raw_fd = wasi_file_raw_fd(f.file()?.deref()).ok_or( let raw_fd = wasi_file_raw_fd(f.file).ok_or(
Error::invalid_argument().context("write subscription fd downcast failed"), Error::invalid_argument().context("write subscription fd downcast failed"),
)?; )?;
pollfds.push(unsafe { PollFd::new(raw_fd, PollFlags::POLLOUT) }); pollfds.push(unsafe { PollFd::new(raw_fd, PollFlags::POLLOUT) });
@@ -49,7 +47,7 @@ impl WasiSched for SyncSched {
} }
let ready = loop { let ready = loop {
let poll_timeout = if let Some(t) = timeout { let poll_timeout = if let Some(t) = poll.earliest_clock_deadline() {
let duration = t.duration_until().unwrap_or(Duration::from_secs(0)); let duration = t.duration_until().unwrap_or(Duration::from_secs(0));
(duration.as_millis() + 1) // XXX try always rounding up? (duration.as_millis() + 1) // XXX try always rounding up?
.try_into() .try_into()
@@ -79,11 +77,7 @@ impl WasiSched for SyncSched {
if let Some(revents) = pollfd.revents() { if let Some(revents) = pollfd.revents() {
let (nbytes, rwsub) = match rwsub { let (nbytes, rwsub) = match rwsub {
Subscription::Read(sub) => { Subscription::Read(sub) => {
let ready = sub let ready = sub.file.num_ready_bytes().await?;
.file()
.expect("validated file already")
.num_ready_bytes()
.await?;
(std::cmp::max(ready, 1), sub) (std::cmp::max(ready, 1), sub)
} }
Subscription::Write(sub) => (0, sub), Subscription::Write(sub) => (0, sub),
@@ -101,7 +95,7 @@ impl WasiSched for SyncSched {
} }
} }
} else { } else {
timeout poll.earliest_clock_deadline()
.expect("timed out") .expect("timed out")
.result() .result()
.expect("timer deadline is past") .expect("timer deadline is past")

View File

@@ -23,7 +23,7 @@
//! The real value of using `anyhow::Error` here is being able to use //! The real value of using `anyhow::Error` here is being able to use
//! `anyhow::Result::context` to aid in debugging of errors. //! `anyhow::Result::context` to aid in debugging of errors.
pub use anyhow::Error; pub use anyhow::{Context, Error};
/// Internal error type for the `wasi-common` crate. /// Internal error type for the `wasi-common` crate.
/// Contains variants of the WASI `$errno` type are added according to what is actually used internally by /// Contains variants of the WASI `$errno` type are added according to what is actually used internally by

View File

@@ -66,7 +66,7 @@ pub use cap_rand::RngCore;
pub use clocks::{SystemTimeSpec, WasiClocks, WasiMonotonicClock, WasiSystemClock}; pub use clocks::{SystemTimeSpec, WasiClocks, WasiMonotonicClock, WasiSystemClock};
pub use ctx::{WasiCtx, WasiCtxBuilder}; pub use ctx::{WasiCtx, WasiCtxBuilder};
pub use dir::WasiDir; pub use dir::WasiDir;
pub use error::{Error, ErrorExt, ErrorKind}; pub use error::{Context, Error, ErrorExt, ErrorKind};
pub use file::WasiFile; pub use file::WasiFile;
pub use sched::{Poll, WasiSched}; pub use sched::{Poll, WasiSched};
pub use string_array::StringArrayError; pub use string_array::StringArrayError;

View File

@@ -1,8 +1,7 @@
use crate::clocks::WasiMonotonicClock; use crate::clocks::WasiMonotonicClock;
use crate::table::Table; use crate::file::WasiFile;
use crate::{Error, ErrorExt}; use crate::Error;
use cap_std::time::Instant; use cap_std::time::Instant;
use std::collections::HashSet;
pub mod subscription; pub mod subscription;
pub use cap_std::time::Duration; pub use cap_std::time::Duration;
@@ -10,11 +9,12 @@ use subscription::{MonotonicClockSubscription, RwSubscription, Subscription, Sub
#[wiggle::async_trait] #[wiggle::async_trait]
pub trait WasiSched { pub trait WasiSched {
async fn poll_oneoff<'a>(&self, poll: &Poll<'a>) -> Result<(), Error>; async fn poll_oneoff<'a>(&self, poll: &'a Poll<'a>) -> Result<(), Error>;
async fn sched_yield(&self) -> Result<(), Error>; async fn sched_yield(&self) -> Result<(), Error>;
async fn sleep(&self, duration: Duration) -> Result<(), Error>; async fn sleep(&self, duration: Duration) -> Result<(), Error>;
} }
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct Userdata(u64); pub struct Userdata(u64);
impl From<u64> for Userdata { impl From<u64> for Userdata {
fn from(u: u64) -> Userdata { fn from(u: u64) -> Userdata {
@@ -28,19 +28,15 @@ impl From<Userdata> for u64 {
} }
} }
pub type PollResults = Vec<(SubscriptionResult, Userdata)>;
pub struct Poll<'a> { pub struct Poll<'a> {
table: &'a Table,
fds: HashSet<u32>,
subs: Vec<(Subscription<'a>, Userdata)>, subs: Vec<(Subscription<'a>, Userdata)>,
} }
impl<'a> Poll<'a> { impl<'a> Poll<'a> {
pub fn new(table: &'a Table) -> Self { pub fn new() -> Self {
Self { Self { subs: Vec::new() }
table,
fds: HashSet::new(),
subs: Vec::new(),
}
} }
pub fn subscribe_monotonic_clock( pub fn subscribe_monotonic_clock(
&mut self, &mut self,
@@ -58,36 +54,18 @@ impl<'a> Poll<'a> {
ud, ud,
)); ));
} }
pub fn subscribe_read(&mut self, fd: u32, ud: Userdata) -> Result<(), Error> { pub fn subscribe_read(&mut self, file: &'a mut dyn WasiFile, ud: Userdata) {
if self.fds.contains(&fd) {
return Err(
Error::invalid_argument().context("Fd can be subscribed to at most once per poll")
);
} else {
self.fds.insert(fd);
}
self.subs self.subs
.push((Subscription::Read(RwSubscription::new(self.table, fd)?), ud)); .push((Subscription::Read(RwSubscription::new(file)), ud));
Ok(())
} }
pub fn subscribe_write(&mut self, fd: u32, ud: Userdata) -> Result<(), Error> { pub fn subscribe_write(&mut self, file: &'a mut dyn WasiFile, ud: Userdata) {
if self.fds.contains(&fd) {
return Err(
Error::invalid_argument().context("Fd can be subscribed to at most once per poll")
);
} else {
self.fds.insert(fd);
}
self.subs.push((
Subscription::Write(RwSubscription::new(self.table, fd)?),
ud,
));
Ok(())
}
pub fn results(self) -> Vec<(SubscriptionResult, Userdata)> {
self.subs self.subs
.into_iter() .push((Subscription::Write(RwSubscription::new(file)), ud));
.filter_map(|(s, ud)| SubscriptionResult::from_subscription(s).map(|r| (r, ud))) }
pub fn results(&self) -> Vec<(SubscriptionResult, Userdata)> {
self.subs
.iter()
.filter_map(|(s, ud)| SubscriptionResult::from_subscription(s).map(|r| (r, *ud)))
.collect() .collect()
} }
pub fn is_empty(&self) -> bool { pub fn is_empty(&self) -> bool {

View File

@@ -1,10 +1,9 @@
use crate::clocks::WasiMonotonicClock; use crate::clocks::WasiMonotonicClock;
use crate::file::{FileCaps, FileEntryMutExt, TableFileExt, WasiFile}; use crate::file::WasiFile;
use crate::table::Table;
use crate::Error; use crate::Error;
use bitflags::bitflags; use bitflags::bitflags;
use cap_std::time::{Duration, Instant}; use cap_std::time::{Duration, Instant};
use std::cell::{Cell, RefMut}; use std::cell::Cell;
bitflags! { bitflags! {
pub struct RwEventFlags: u32 { pub struct RwEventFlags: u32 {
@@ -13,29 +12,16 @@ bitflags! {
} }
pub struct RwSubscription<'a> { pub struct RwSubscription<'a> {
table: &'a Table, pub file: &'a mut dyn WasiFile,
fd: u32,
status: Cell<Option<Result<(u64, RwEventFlags), Error>>>, status: Cell<Option<Result<(u64, RwEventFlags), Error>>>,
} }
impl<'a> RwSubscription<'a> { impl<'a> RwSubscription<'a> {
/// Create an RwSubscription. This constructor checks to make sure the file we need exists, and pub fn new(file: &'a mut dyn WasiFile) -> Self {
/// has the correct rights. But, we can't hold onto the WasiFile RefMut inside this structure Self {
/// (Pat can't convince borrow checker, either not clever enough or a rustc bug), so we need to file,
/// re-borrow at use time.
pub fn new(table: &'a Table, fd: u32) -> Result<Self, Error> {
let _ = table.get_file_mut(fd)?.get_cap(FileCaps::POLL_READWRITE)?;
Ok(Self {
table,
fd,
status: Cell::new(None), status: Cell::new(None),
}) }
}
/// This accessor could fail if there is an outstanding borrow of the file.
pub fn file(&self) -> Result<RefMut<'a, dyn WasiFile>, Error> {
self.table
.get_file_mut(self.fd)?
.get_cap(FileCaps::POLL_READWRITE)
} }
pub fn complete(&self, size: u64, flags: RwEventFlags) { pub fn complete(&self, size: u64, flags: RwEventFlags) {
self.status.set(Some(Ok((size, flags)))) self.status.set(Some(Ok((size, flags))))
@@ -43,8 +29,8 @@ impl<'a> RwSubscription<'a> {
pub fn error(&self, error: Error) { pub fn error(&self, error: Error) {
self.status.set(Some(Err(error))) self.status.set(Some(Err(error)))
} }
pub fn result(self) -> Option<Result<(u64, RwEventFlags), Error>> { pub fn result(&self) -> Option<Result<(u64, RwEventFlags), Error>> {
self.status.into_inner() self.status.take()
} }
} }
@@ -83,7 +69,7 @@ pub enum SubscriptionResult {
} }
impl SubscriptionResult { impl SubscriptionResult {
pub fn from_subscription(s: Subscription) -> Option<SubscriptionResult> { pub fn from_subscription(s: &Subscription) -> Option<SubscriptionResult> {
match s { match s {
Subscription::Read(s) => s.result().map(SubscriptionResult::Read), Subscription::Read(s) => s.result().map(SubscriptionResult::Read),
Subscription::Write(s) => s.result().map(SubscriptionResult::Write), Subscription::Write(s) => s.result().map(SubscriptionResult::Write),

View File

@@ -1,12 +1,14 @@
use crate::file::{FileCaps, FileEntryExt, TableFileExt}; use crate::file::{FileCaps, FileEntryExt, FileEntryMutExt, TableFileExt, WasiFile};
use crate::sched::{ use crate::sched::{
subscription::{RwEventFlags, SubscriptionResult}, subscription::{RwEventFlags, SubscriptionResult},
Poll, Poll, Userdata,
}; };
use crate::snapshots::preview_1::types as snapshot1_types; use crate::snapshots::preview_1::types as snapshot1_types;
use crate::snapshots::preview_1::wasi_snapshot_preview1::WasiSnapshotPreview1 as Snapshot1; use crate::snapshots::preview_1::wasi_snapshot_preview1::WasiSnapshotPreview1 as Snapshot1;
use crate::{Error, ErrorExt, WasiCtx}; use crate::{Error, ErrorExt, WasiCtx};
use cap_std::time::Duration; use cap_std::time::Duration;
use std::cell::RefMut;
use std::collections::HashSet;
use std::convert::{TryFrom, TryInto}; use std::convert::{TryFrom, TryInto};
use std::io::{IoSlice, IoSliceMut}; use std::io::{IoSlice, IoSliceMut};
use std::ops::Deref; use std::ops::Deref;
@@ -778,7 +780,11 @@ impl wasi_unstable::WasiUnstable for WasiCtx {
} }
let table = self.table(); let table = self.table();
let mut poll = Poll::new(&table); let mut sub_fds: HashSet<types::Fd> = HashSet::new();
// We need these refmuts to outlive Poll, which will hold the &mut dyn WasiFile inside
let mut read_refs: Vec<(RefMut<'_, dyn WasiFile>, Userdata)> = Vec::new();
let mut write_refs: Vec<(RefMut<'_, dyn WasiFile>, Userdata)> = Vec::new();
let mut poll = Poll::new();
let subs = subs.as_array(nsubscriptions); let subs = subs.as_array(nsubscriptions);
for sub_elem in subs.iter() { for sub_elem in subs.iter() {
@@ -816,11 +822,29 @@ impl wasi_unstable::WasiUnstable for WasiCtx {
}, },
types::SubscriptionU::FdRead(readsub) => { types::SubscriptionU::FdRead(readsub) => {
let fd = readsub.file_descriptor; let fd = readsub.file_descriptor;
poll.subscribe_read(u32::from(fd), sub.userdata.into())?; if sub_fds.contains(&fd) {
return Err(Error::invalid_argument()
.context("Fd can be subscribed to at most once per poll"));
} else {
sub_fds.insert(fd);
}
let file_ref = table
.get_file_mut(u32::from(fd))?
.get_cap(FileCaps::POLL_READWRITE)?;
read_refs.push((file_ref, sub.userdata.into()));
} }
types::SubscriptionU::FdWrite(writesub) => { types::SubscriptionU::FdWrite(writesub) => {
let fd = writesub.file_descriptor; let fd = writesub.file_descriptor;
poll.subscribe_write(u32::from(fd), sub.userdata.into())?; if sub_fds.contains(&fd) {
return Err(Error::invalid_argument()
.context("Fd can be subscribed to at most once per poll"));
} else {
sub_fds.insert(fd);
}
let file_ref = table
.get_file_mut(u32::from(fd))?
.get_cap(FileCaps::POLL_READWRITE)?;
write_refs.push((file_ref, sub.userdata.into()));
} }
} }
} }

View File

@@ -2,17 +2,18 @@ use crate::{
dir::{DirCaps, DirEntry, DirEntryExt, DirFdStat, ReaddirCursor, ReaddirEntity, TableDirExt}, dir::{DirCaps, DirEntry, DirEntryExt, DirFdStat, ReaddirCursor, ReaddirEntity, TableDirExt},
file::{ file::{
Advice, FdFlags, FdStat, FileCaps, FileEntry, FileEntryExt, FileEntryMutExt, FileType, Advice, FdFlags, FdStat, FileCaps, FileEntry, FileEntryExt, FileEntryMutExt, FileType,
Filestat, OFlags, TableFileExt, Filestat, OFlags, TableFileExt, WasiFile,
}, },
sched::{ sched::{
subscription::{RwEventFlags, SubscriptionResult}, subscription::{RwEventFlags, SubscriptionResult},
Poll, Poll, Userdata,
}, },
Error, ErrorExt, ErrorKind, SystemTimeSpec, WasiCtx, Error, ErrorExt, ErrorKind, SystemTimeSpec, WasiCtx,
}; };
use anyhow::Context; use anyhow::Context;
use cap_std::time::{Duration, SystemClock}; use cap_std::time::{Duration, SystemClock};
use std::cell::{Ref, RefMut}; use std::cell::{Ref, RefMut};
use std::collections::HashSet;
use std::convert::{TryFrom, TryInto}; use std::convert::{TryFrom, TryInto};
use std::io::{IoSlice, IoSliceMut}; use std::io::{IoSlice, IoSliceMut};
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
@@ -970,7 +971,11 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx {
} }
let table = self.table(); let table = self.table();
let mut poll = Poll::new(&table); let mut sub_fds: HashSet<types::Fd> = HashSet::new();
// We need these refmuts to outlive Poll, which will hold the &mut dyn WasiFile inside
let mut read_refs: Vec<(RefMut<'_, dyn WasiFile>, Userdata)> = Vec::new();
let mut write_refs: Vec<(RefMut<'_, dyn WasiFile>, Userdata)> = Vec::new();
let mut poll = Poll::new();
let subs = subs.as_array(nsubscriptions); let subs = subs.as_array(nsubscriptions);
for sub_elem in subs.iter() { for sub_elem in subs.iter() {
@@ -1008,15 +1013,40 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx {
}, },
types::SubscriptionU::FdRead(readsub) => { types::SubscriptionU::FdRead(readsub) => {
let fd = readsub.file_descriptor; let fd = readsub.file_descriptor;
poll.subscribe_read(u32::from(fd), sub.userdata.into())?; if sub_fds.contains(&fd) {
return Err(Error::invalid_argument()
.context("Fd can be subscribed to at most once per poll"));
} else {
sub_fds.insert(fd);
}
let file_ref = table
.get_file_mut(u32::from(fd))?
.get_cap(FileCaps::POLL_READWRITE)?;
read_refs.push((file_ref, sub.userdata.into()));
} }
types::SubscriptionU::FdWrite(writesub) => { types::SubscriptionU::FdWrite(writesub) => {
let fd = writesub.file_descriptor; let fd = writesub.file_descriptor;
poll.subscribe_write(u32::from(fd), sub.userdata.into())?; if sub_fds.contains(&fd) {
return Err(Error::invalid_argument()
.context("Fd can be subscribed to at most once per poll"));
} else {
sub_fds.insert(fd);
}
let file_ref = table
.get_file_mut(u32::from(fd))?
.get_cap(FileCaps::POLL_READWRITE)?;
write_refs.push((file_ref, sub.userdata.into()));
} }
} }
} }
for (f, ud) in read_refs.iter_mut() {
poll.subscribe_read(f.deref_mut(), *ud);
}
for (f, ud) in write_refs.iter_mut() {
poll.subscribe_write(f.deref_mut(), *ud);
}
self.sched.poll_oneoff(&poll).await?; self.sched.poll_oneoff(&poll).await?;
let results = poll.results(); let results = poll.results();