diff --git a/crates/wasi-common/tokio/Cargo.toml b/crates/wasi-common/tokio/Cargo.toml index f480de0635..d67fc72914 100644 --- a/crates/wasi-common/tokio/Cargo.toml +++ b/crates/wasi-common/tokio/Cargo.toml @@ -22,6 +22,15 @@ cap-time-ext = "0.13.7" fs-set-times = "0.3.1" unsafe-io = "0.6.2" system-interface = { version = "0.6.3", features = ["cap_std_impls"] } +tracing = "0.1.19" +bitflags = "1.2" + +[target.'cfg(unix)'.dependencies] +libc = "0.2" + +[target.'cfg(windows)'.dependencies] +winapi = "0.3" +lazy_static = "1.4" [dev-dependencies] tempfile = "3.1.0" diff --git a/crates/wasi-common/tokio/src/lib.rs b/crates/wasi-common/tokio/src/lib.rs index 7ae8849259..a5ff5f8ca7 100644 --- a/crates/wasi-common/tokio/src/lib.rs +++ b/crates/wasi-common/tokio/src/lib.rs @@ -1,5 +1,6 @@ mod dir; mod file; +mod sched; use std::cell::RefCell; use std::path::Path; @@ -7,27 +8,7 @@ use std::rc::Rc; pub use wasi_cap_std_sync::{clocks_ctx, random_ctx, Dir}; use wasi_common::{Error, Table, WasiCtx}; -pub fn sched_ctx() -> Box { - use wasi_common::sched::{Duration, Poll, WasiSched}; - struct AsyncSched; - - #[wiggle::async_trait] - impl WasiSched for AsyncSched { - async fn poll_oneoff<'a>(&self, _poll: &'_ Poll<'a>) -> Result<(), Error> { - todo!() - } - async fn sched_yield(&self) -> Result<(), Error> { - tokio::task::yield_now().await; - Ok(()) - } - async fn sleep(&self, duration: Duration) -> Result<(), Error> { - tokio::time::sleep(duration).await; - Ok(()) - } - } - - Box::new(AsyncSched) -} +use crate::sched::sched_ctx; pub struct WasiCtxBuilder(wasi_common::WasiCtxBuilder); diff --git a/crates/wasi-common/tokio/src/sched.rs b/crates/wasi-common/tokio/src/sched.rs new file mode 100644 index 0000000000..fd41f32996 --- /dev/null +++ b/crates/wasi-common/tokio/src/sched.rs @@ -0,0 +1,35 @@ +#[cfg(unix)] +mod unix; +#[cfg(unix)] +use unix::poll_oneoff; + +#[cfg(windows)] +mod windows; +#[cfg(windows)] +use windows::poll_oneoff; + +use wasi_common::{ + sched::{Duration, Poll, WasiSched}, + Error, +}; + +pub fn sched_ctx() -> Box { + struct AsyncSched; + + #[wiggle::async_trait] + impl WasiSched for AsyncSched { + async fn poll_oneoff<'a>(&self, poll: &'_ Poll<'a>) -> Result<(), Error> { + poll_oneoff(poll).await + } + async fn sched_yield(&self) -> Result<(), Error> { + tokio::task::yield_now().await; + Ok(()) + } + async fn sleep(&self, duration: Duration) -> Result<(), Error> { + tokio::time::sleep(duration).await; + Ok(()) + } + } + + Box::new(AsyncSched) +} diff --git a/crates/wasi-common/tokio/src/sched/unix.rs b/crates/wasi-common/tokio/src/sched/unix.rs new file mode 100644 index 0000000000..83ddb2c839 --- /dev/null +++ b/crates/wasi-common/tokio/src/sched/unix.rs @@ -0,0 +1,176 @@ +use cap_std::time::Duration; +use std::convert::TryInto; +use std::ops::Deref; +use std::os::unix::io::{AsRawFd, RawFd}; +use wasi_common::{ + file::WasiFile, + sched::{ + subscription::{RwEventFlags, Subscription}, + Poll, + }, + Error, ErrorExt, +}; + +use poll::{PollFd, PollFlags}; + +pub async fn poll_oneoff<'a>(poll: &'_ Poll<'a>) -> Result<(), Error> { + if poll.is_empty() { + return Ok(()); + } + let mut pollfds = Vec::new(); + let timeout = poll.earliest_clock_deadline(); + for s in poll.rw_subscriptions() { + match s { + Subscription::Read(f) => { + let raw_fd = wasi_file_raw_fd(f.file.deref()).ok_or( + Error::invalid_argument().context("read subscription fd downcast failed"), + )?; + pollfds.push(unsafe { PollFd::new(raw_fd, PollFlags::POLLIN) }); + } + + Subscription::Write(f) => { + let raw_fd = wasi_file_raw_fd(f.file.deref()).ok_or( + Error::invalid_argument().context("write subscription fd downcast failed"), + )?; + pollfds.push(unsafe { PollFd::new(raw_fd, PollFlags::POLLOUT) }); + } + Subscription::MonotonicClock { .. } => unreachable!(), + } + } + + let ready = loop { + let poll_timeout = if let Some(t) = timeout { + let duration = t.duration_until().unwrap_or(Duration::from_secs(0)); + (duration.as_millis() + 1) // XXX try always rounding up? + .try_into() + .map_err(|_| Error::overflow().context("poll timeout"))? + } else { + libc::c_int::max_value() + }; + tracing::debug!( + poll_timeout = tracing::field::debug(poll_timeout), + poll_fds = tracing::field::debug(&pollfds), + "poll" + ); + match poll::poll(&mut pollfds, poll_timeout) { + Ok(ready) => break ready, + Err(_) => { + let last_err = std::io::Error::last_os_error(); + if last_err.raw_os_error().unwrap() == libc::EINTR { + continue; + } else { + return Err(last_err.into()); + } + } + } + }; + if ready > 0 { + for (rwsub, pollfd) in poll.rw_subscriptions().zip(pollfds.into_iter()) { + if let Some(revents) = pollfd.revents() { + let (nbytes, rwsub) = match rwsub { + Subscription::Read(sub) => { + let ready = sub.file.num_ready_bytes().await?; + (std::cmp::max(ready, 1), sub) + } + Subscription::Write(sub) => (0, sub), + _ => unreachable!(), + }; + if revents.contains(PollFlags::POLLNVAL) { + rwsub.error(Error::badf()); + } else if revents.contains(PollFlags::POLLERR) { + rwsub.error(Error::io()); + } else if revents.contains(PollFlags::POLLHUP) { + rwsub.complete(nbytes, RwEventFlags::HANGUP); + } else { + rwsub.complete(nbytes, RwEventFlags::empty()); + }; + } + } + } else { + timeout + .expect("timed out") + .result() + .expect("timer deadline is past") + .unwrap() + } + Ok(()) +} + +fn wasi_file_raw_fd(f: &dyn WasiFile) -> Option { + let a = f.as_any(); + if a.is::() { + Some(a.downcast_ref::().unwrap().as_raw_fd()) + } else if a.is::() { + Some(a.downcast_ref::().unwrap().as_raw_fd()) + } else if a.is::() { + Some( + a.downcast_ref::() + .unwrap() + .as_raw_fd(), + ) + } else if a.is::() { + Some( + a.downcast_ref::() + .unwrap() + .as_raw_fd(), + ) + } else { + None + } +} + +mod poll { + use bitflags::bitflags; + use std::convert::TryInto; + use std::os::unix::io::RawFd; + + bitflags! { + pub struct PollFlags: libc::c_short { + const POLLIN = libc::POLLIN; + const POLLPRI = libc::POLLPRI; + const POLLOUT = libc::POLLOUT; + const POLLRDNORM = libc::POLLRDNORM; + const POLLWRNORM = libc::POLLWRNORM; + const POLLRDBAND = libc::POLLRDBAND; + const POLLWRBAND = libc::POLLWRBAND; + const POLLERR = libc::POLLERR; + const POLLHUP = libc::POLLHUP; + const POLLNVAL = libc::POLLNVAL; + } + } + + #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] + #[repr(C)] + pub struct PollFd(libc::pollfd); + + impl PollFd { + pub unsafe fn new(fd: RawFd, events: PollFlags) -> Self { + Self(libc::pollfd { + fd, + events: events.bits(), + revents: PollFlags::empty().bits(), + }) + } + + pub fn revents(self) -> Option { + PollFlags::from_bits(self.0.revents) + } + } + + pub fn poll(fds: &mut [PollFd], timeout: libc::c_int) -> Result { + let nready = unsafe { + libc::poll( + fds.as_mut_ptr() as *mut libc::pollfd, + fds.len() as libc::nfds_t, + timeout, + ) + }; + if nready == -1 { + Err(std::io::Error::last_os_error()) + } else { + // When poll doesn't fail, its return value is a non-negative int, which will + // always be convertable to usize, so we can unwrap() here. + Ok(nready.try_into().unwrap()) + } + } +} diff --git a/crates/wasi-common/tokio/src/sched/windows.rs b/crates/wasi-common/tokio/src/sched/windows.rs new file mode 100644 index 0000000000..a3c15f3c2f --- /dev/null +++ b/crates/wasi-common/tokio/src/sched/windows.rs @@ -0,0 +1,242 @@ +use anyhow::Context; +use std::ops::Deref; +use std::os::windows::io::{AsRawHandle, RawHandle}; +use std::sync::mpsc::{self, Receiver, RecvTimeoutError, Sender, TryRecvError}; +use std::sync::Mutex; +use std::thread; +use std::time::Duration; +use wasi_common::{ + file::WasiFile, + sched::{ + subscription::{RwEventFlags, Subscription}, + Poll, + }, + Error, ErrorExt, +}; +pub async fn poll_oneoff<'a>(poll: &'_ Poll<'a>) -> Result<(), Error> { + if poll.is_empty() { + return Ok(()); + } + + let mut ready = false; + let timeout = poll.earliest_clock_deadline(); + + let mut stdin_read_subs = Vec::new(); + let mut immediate_subs = Vec::new(); + for s in poll.rw_subscriptions() { + match s { + Subscription::Read(r) if r.file.as_any().is::() => { + stdin_read_subs.push(r); + } + Subscription::Read(rw) | Subscription::Write(rw) => { + if wasi_file_raw_handle(rw.file.deref()).is_some() { + immediate_subs.push(s); + } else { + return Err(Error::invalid_argument() + .context("read/write subscription fd downcast failed")); + } + } + Subscription::MonotonicClock { .. } => unreachable!(), + } + } + + if !stdin_read_subs.is_empty() { + let waitmode = if let Some(t) = timeout { + if let Some(duration) = t.duration_until() { + WaitMode::Timeout(duration) + } else { + WaitMode::Immediate + } + } else { + if ready { + WaitMode::Immediate + } else { + WaitMode::Infinite + } + }; + let state = STDIN_POLL + .lock() + .map_err(|_| Error::trap("failed to take lock of STDIN_POLL"))? + .poll(waitmode)?; + for readsub in stdin_read_subs.into_iter() { + match state { + PollState::Ready => { + readsub.complete(1, RwEventFlags::empty()); + ready = true; + } + PollState::NotReady | PollState::TimedOut => {} + PollState::Error(ref e) => { + // Unfortunately, we need to deliver the Error to each of the + // subscriptions, but there is no Clone on std::io::Error. So, we convert it to the + // kind, and then back to std::io::Error, and finally to anyhow::Error. + // When its time to turn this into an errno elsewhere, the error kind will + // be inspected. + let ekind = e.kind(); + let ioerror = std::io::Error::from(ekind); + readsub.error(ioerror.into()); + ready = true; + } + } + } + } + for sub in immediate_subs { + match sub { + Subscription::Read(r) => { + // XXX This doesnt strictly preserve the behavior in the earlier + // implementation, which would always do complete(0) for reads from + // stdout/err. + match r.file.num_ready_bytes().await { + Ok(ready_bytes) => { + r.complete(ready_bytes, RwEventFlags::empty()); + ready = true; + } + Err(e) => { + r.error(e); + ready = true; + } + } + } + Subscription::Write(w) => { + // Everything is always ready for writing, apparently? + w.complete(0, RwEventFlags::empty()); + ready = true; + } + Subscription::MonotonicClock { .. } => unreachable!(), + } + } + + if !ready { + if let Some(t) = timeout { + if let Some(duration) = t.duration_until() { + thread::sleep(duration); + } + } + } + + Ok(()) +} + +fn wasi_file_raw_handle(f: &dyn WasiFile) -> Option { + let a = f.as_any(); + if a.is::() { + Some( + a.downcast_ref::() + .unwrap() + .as_raw_handle(), + ) + } else if a.is::() { + Some( + a.downcast_ref::() + .unwrap() + .as_raw_handle(), + ) + } else if a.is::() { + Some( + a.downcast_ref::() + .unwrap() + .as_raw_handle(), + ) + } else if a.is::() { + Some( + a.downcast_ref::() + .unwrap() + .as_raw_handle(), + ) + } else { + None + } +} + +enum PollState { + Ready, + NotReady, // Not ready, but did not wait + TimedOut, // Not ready, waited until timeout + Error(std::io::Error), +} + +enum WaitMode { + Timeout(Duration), + Infinite, + Immediate, +} + +struct StdinPoll { + request_tx: Sender<()>, + notify_rx: Receiver, +} + +lazy_static::lazy_static! { + static ref STDIN_POLL: Mutex = StdinPoll::new(); +} + +impl StdinPoll { + pub fn new() -> Mutex { + let (request_tx, request_rx) = mpsc::channel(); + let (notify_tx, notify_rx) = mpsc::channel(); + thread::spawn(move || Self::event_loop(request_rx, notify_tx)); + Mutex::new(StdinPoll { + request_tx, + notify_rx, + }) + } + + // This function should not be used directly. + // Correctness of this function crucially depends on the fact that + // mpsc::Receiver is !Sync. + fn poll(&self, wait_mode: WaitMode) -> Result { + match self.notify_rx.try_recv() { + // Clean up possibly unread result from previous poll. + Ok(_) | Err(TryRecvError::Empty) => {} + Err(TryRecvError::Disconnected) => { + return Err(Error::trap("StdinPoll notify_rx channel closed")) + } + } + + // Notify the worker thread to poll stdin + self.request_tx + .send(()) + .context("request_tx channel closed")?; + + // Wait for the worker thread to send a readiness notification + match wait_mode { + WaitMode::Timeout(timeout) => match self.notify_rx.recv_timeout(timeout) { + Ok(r) => Ok(r), + Err(RecvTimeoutError::Timeout) => Ok(PollState::TimedOut), + Err(RecvTimeoutError::Disconnected) => { + Err(Error::trap("StdinPoll notify_rx channel closed")) + } + }, + WaitMode::Infinite => self + .notify_rx + .recv() + .context("StdinPoll notify_rx channel closed"), + WaitMode::Immediate => match self.notify_rx.try_recv() { + Ok(r) => Ok(r), + Err(TryRecvError::Empty) => Ok(PollState::NotReady), + Err(TryRecvError::Disconnected) => { + Err(Error::trap("StdinPoll notify_rx channel closed")) + } + }, + } + } + + fn event_loop(request_rx: Receiver<()>, notify_tx: Sender) -> ! { + use std::io::BufRead; + loop { + // Wait on a request: + request_rx.recv().expect("request_rx channel"); + // Wait for data to appear in stdin. If fill_buf returns any slice, it means + // that either: + // (a) there is some data in stdin, if non-empty, + // (b) EOF was recieved, if its empty + // Linux returns `POLLIN` in both cases, so we imitate this behavior. + let resp = match std::io::stdin().lock().fill_buf() { + Ok(_) => PollState::Ready, + Err(e) => PollState::Error(e), + }; + // Notify about data in stdin. If the read on this channel has timed out, the + // next poller will have to clean the channel. + notify_tx.send(resp).expect("notify_tx channel"); + } + } +}