diff --git a/Cargo.lock b/Cargo.lock index e7d6bac18d..2ac571fc39 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1590,6 +1590,28 @@ dependencies = [ "autocfg 1.0.1", ] +[[package]] +name = "mio" +version = "0.7.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf80d3e903b34e0bd7282b218398aec54e082c840d9baf8339e0080a0c542956" +dependencies = [ + "libc", + "log", + "miow", + "ntapi", + "winapi", +] + +[[package]] +name = "miow" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9f1c5b025cda876f66ef43a113f91ebc9f4ccef34843000e0adf6ebbab84e21" +dependencies = [ + "winapi", +] + [[package]] name = "more-asserts" version = "0.2.1" @@ -1619,6 +1641,15 @@ dependencies = [ "version_check", ] +[[package]] +name = "ntapi" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f6bb902e437b6d86e03cce10a7e2af662292c5dfef23b65899ea3ac9354ad44" +dependencies = [ + "winapi", +] + [[package]] name = "num-bigint" version = "0.2.6" @@ -2827,7 +2858,9 @@ checksum = "83f0c8e7c0addab50b663055baf787d0af7f413a46e6e7fb9559a4e4db7137a5" dependencies = [ "autocfg 1.0.1", "bytes", + "libc", "memchr", + "mio", "num_cpus", "pin-project-lite", "tokio-macros", diff --git a/crates/wasi-common/tokio/Cargo.toml b/crates/wasi-common/tokio/Cargo.toml index d67fc72914..777578bd4b 100644 --- a/crates/wasi-common/tokio/Cargo.toml +++ b/crates/wasi-common/tokio/Cargo.toml @@ -15,7 +15,7 @@ include = ["src/**/*", "LICENSE" ] wasi-common = { path = "../", version = "0.26.0" } wasi-cap-std-sync = { path = "../cap-std-sync", version = "0.26.0" } wiggle = { path = "../../wiggle", version = "0.26.0" } -tokio = { version = "1.5.0", features = [ "rt", "fs", "time" , "io-util"] } +tokio = { version = "1.5.0", features = [ "rt", "fs", "time", "io-util", "net"] } cap-std = "0.13.7" cap-fs-ext = "0.13.7" cap-time-ext = "0.13.7" @@ -34,4 +34,4 @@ lazy_static = "1.4" [dev-dependencies] tempfile = "3.1.0" -tokio = { features = [ "macros" ] } +tokio = { version = "1.5.0", features = [ "macros" ] } diff --git a/crates/wasi-common/tokio/src/file.rs b/crates/wasi-common/tokio/src/file.rs index a47e6b4ccf..2acdbc0f4a 100644 --- a/crates/wasi-common/tokio/src/file.rs +++ b/crates/wasi-common/tokio/src/file.rs @@ -184,13 +184,47 @@ impl WasiFile for File { use unsafe_io::AsUnsafeFile; asyncify(|| self.0.as_file_view().num_ready_bytes()).await } + #[cfg(not(windows))] async fn readable(&mut self) -> Result<(), Error> { - todo!("implement this in terms of tokio::io::AsyncFd") + // The Inner impls OwnsRaw, which asserts exclusive use of the handle by the owned object. + // AsyncFd needs to wrap an owned `impl std::os::unix::io::AsRawFd`. Rather than introduce + // mutability to let it own the `Inner`, we are depending on the `&mut self` bound on this + // async method to ensure this is the only Future which can access the RawFd during the + // lifetime of the AsyncFd. + use tokio::io::{unix::AsyncFd, Interest}; + use unsafe_io::os::posish::AsRawFd; + let rawfd = self.0.as_raw_fd(); + let asyncfd = AsyncFd::with_interest(rawfd, Interest::READABLE)?; + let _ = asyncfd.readable().await?; + Ok(()) } + #[cfg(windows)] + async fn readable(&mut self) -> Result<(), Error> { + // Windows uses a rawfd based scheduler :( + Err(Error::badf()) + } + + #[cfg(not(windows))] async fn writable(&mut self) -> Result<(), Error> { - todo!("implement this in terms of tokio::io::AsyncFd") + // The Inner impls OwnsRaw, which asserts exclusive use of the handle by the owned object. + // AsyncFd needs to wrap an owned `impl std::os::unix::io::AsRawFd`. Rather than introduce + // mutability to let it own the `Inner`, we are depending on the `&mut self` bound on this + // async method to ensure this is the only Future which can access the RawFd during the + // lifetime of the AsyncFd. + use tokio::io::{unix::AsyncFd, Interest}; + use unsafe_io::os::posish::AsRawFd; + let rawfd = self.0.as_raw_fd(); + let asyncfd = AsyncFd::with_interest(rawfd, Interest::WRITABLE)?; + let _ = asyncfd.writable().await?; + Ok(()) + } + #[cfg(windows)] + async fn writable(&mut self) -> Result<(), Error> { + // Windows uses a rawfd based scheduler :( + Err(Error::badf()) } } + pub fn filetype_from(ft: &cap_std::fs::FileType) -> FileType { use cap_fs_ext::FileTypeExt; if ft.is_dir() { diff --git a/crates/wasi-common/tokio/src/sched/unix.rs b/crates/wasi-common/tokio/src/sched/unix.rs index 07a1cf0a9f..7445394840 100644 --- a/crates/wasi-common/tokio/src/sched/unix.rs +++ b/crates/wasi-common/tokio/src/sched/unix.rs @@ -1,7 +1,9 @@ use cap_std::time::Duration; use std::convert::TryInto; +use std::future::{Future, Poll as FPoll}; use std::ops::Deref; -use std::os::unix::io::{AsRawFd, RawFd}; +use std::pin::Pin; +use std::task::Context; use wasi_common::{ file::WasiFile, sched::{ @@ -11,147 +13,36 @@ use wasi_common::{ Error, ErrorExt, }; -use poll::{PollFd, PollFlags}; - pub async fn poll_oneoff<'a>(poll: &'_ Poll<'a>) -> Result<(), Error> { if poll.is_empty() { return Ok(()); } - let mut pollfds = Vec::new(); + let mut futures: Vec>>>> = Vec::new(); let timeout = poll.earliest_clock_deadline(); for s in poll.rw_subscriptions() { match s { Subscription::Read(f) => { - let raw_fd = wasi_file_raw_fd(f.file.deref()).ok_or( - Error::invalid_argument().context("read subscription fd downcast failed"), - )?; - pollfds.push(unsafe { PollFd::new(raw_fd, PollFlags::POLLIN) }); + futures.push(Box::pin(async move { + f.file.readable().await?; + f.complete(f.file.num_ready_bytes().await?, RwEventFlags::empty()); + Ok(()) + })); } Subscription::Write(f) => { - let raw_fd = wasi_file_raw_fd(f.file.deref()).ok_or( - Error::invalid_argument().context("write subscription fd downcast failed"), - )?; - pollfds.push(unsafe { PollFd::new(raw_fd, PollFlags::POLLOUT) }); + futures.push(Box::pin(async move { + f.file.writable().await?; + f.complete(0, RwEventFlags::empty()); + Ok(()) + })); } Subscription::MonotonicClock { .. } => unreachable!(), } } - let ready = loop { - let poll_timeout = if let Some(t) = timeout { - let duration = t.duration_until().unwrap_or(Duration::from_secs(0)); - (duration.as_millis() + 1) // XXX try always rounding up? - .try_into() - .map_err(|_| Error::overflow().context("poll timeout"))? - } else { - libc::c_int::max_value() - }; - tracing::debug!( - poll_timeout = tracing::field::debug(poll_timeout), - poll_fds = tracing::field::debug(&pollfds), - "poll" - ); - match poll::poll(&mut pollfds, poll_timeout) { - Ok(ready) => break ready, - Err(_) => { - let last_err = std::io::Error::last_os_error(); - if last_err.raw_os_error().unwrap() == libc::EINTR { - continue; - } else { - return Err(last_err.into()); - } - } - } - }; - if ready > 0 { - for (rwsub, pollfd) in poll.rw_subscriptions().zip(pollfds.into_iter()) { - if let Some(revents) = pollfd.revents() { - let (nbytes, rwsub) = match rwsub { - Subscription::Read(sub) => { - let ready = sub.file.num_ready_bytes().await?; - (std::cmp::max(ready, 1), sub) - } - Subscription::Write(sub) => (0, sub), - _ => unreachable!(), - }; - if revents.contains(PollFlags::POLLNVAL) { - rwsub.error(Error::badf()); - } else if revents.contains(PollFlags::POLLERR) { - rwsub.error(Error::io()); - } else if revents.contains(PollFlags::POLLHUP) { - rwsub.complete(nbytes, RwEventFlags::HANGUP); - } else { - rwsub.complete(nbytes, RwEventFlags::empty()); - }; - } - } - } else { - timeout - .expect("timed out") - .result() - .expect("timer deadline is past") - .unwrap() + // Incorrect, but lets get the type errors fixed before we write the right multiplexer here: + for f in futures { + f.await?; } Ok(()) } - -fn wasi_file_raw_fd(f: &dyn WasiFile) -> Option { - todo!() -} - -mod poll { - use bitflags::bitflags; - use std::convert::TryInto; - use std::os::unix::io::RawFd; - - bitflags! { - pub struct PollFlags: libc::c_short { - const POLLIN = libc::POLLIN; - const POLLPRI = libc::POLLPRI; - const POLLOUT = libc::POLLOUT; - const POLLRDNORM = libc::POLLRDNORM; - const POLLWRNORM = libc::POLLWRNORM; - const POLLRDBAND = libc::POLLRDBAND; - const POLLWRBAND = libc::POLLWRBAND; - const POLLERR = libc::POLLERR; - const POLLHUP = libc::POLLHUP; - const POLLNVAL = libc::POLLNVAL; - } - } - - #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] - #[repr(C)] - pub struct PollFd(libc::pollfd); - - impl PollFd { - pub unsafe fn new(fd: RawFd, events: PollFlags) -> Self { - Self(libc::pollfd { - fd, - events: events.bits(), - revents: PollFlags::empty().bits(), - }) - } - - pub fn revents(self) -> Option { - PollFlags::from_bits(self.0.revents) - } - } - - pub fn poll(fds: &mut [PollFd], timeout: libc::c_int) -> Result { - let nready = unsafe { - libc::poll( - fds.as_mut_ptr() as *mut libc::pollfd, - fds.len() as libc::nfds_t, - timeout, - ) - }; - if nready == -1 { - Err(std::io::Error::last_os_error()) - } else { - // When poll doesn't fail, its return value is a non-negative int, which will - // always be convertable to usize, so we can unwrap() here. - Ok(nready.try_into().unwrap()) - } - } -}