use cap_std::time::Duration; use std::convert::TryInto; use std::os::unix::io::{AsRawFd, RawFd}; use wasi_common::{ file::WasiFile, sched::{ subscription::{RwEventFlags, Subscription}, Poll, }, Error, ErrorExt, }; use poll::{PollFd, PollFlags}; pub async fn poll_oneoff<'a>(poll: &mut Poll<'a>) -> Result<(), Error> { if poll.is_empty() { return Ok(()); } let mut pollfds = Vec::new(); for s in poll.rw_subscriptions() { match s { Subscription::Read(f) => { let raw_fd = wasi_file_raw_fd(f.file).ok_or( Error::invalid_argument().context("read subscription fd downcast failed"), )?; pollfds.push(unsafe { PollFd::new(raw_fd, PollFlags::POLLIN) }); } Subscription::Write(f) => { let raw_fd = wasi_file_raw_fd(f.file).ok_or( Error::invalid_argument().context("write subscription fd downcast failed"), )?; pollfds.push(unsafe { PollFd::new(raw_fd, PollFlags::POLLOUT) }); } Subscription::MonotonicClock { .. } => unreachable!(), } } let ready = loop { let poll_timeout = if let Some(t) = poll.earliest_clock_deadline() { let duration = t.duration_until().unwrap_or(Duration::from_secs(0)); (duration.as_millis() + 1) // XXX try always rounding up? .try_into() .map_err(|_| Error::overflow().context("poll timeout"))? } else { libc::c_int::max_value() }; tracing::debug!( poll_timeout = tracing::field::debug(poll_timeout), poll_fds = tracing::field::debug(&pollfds), "poll" ); match poll::poll(&mut pollfds, poll_timeout) { Ok(ready) => break ready, Err(_) => { let last_err = std::io::Error::last_os_error(); if last_err.raw_os_error().unwrap() == libc::EINTR { continue; } else { return Err(last_err.into()); } } } }; if ready > 0 { for (rwsub, pollfd) in poll.rw_subscriptions().zip(pollfds.into_iter()) { if let Some(revents) = pollfd.revents() { let (nbytes, rwsub) = match rwsub { Subscription::Read(sub) => { let ready = sub.file.num_ready_bytes().await?; (std::cmp::max(ready, 1), sub) } Subscription::Write(sub) => (0, sub), _ => unreachable!(), }; if revents.contains(PollFlags::POLLNVAL) { rwsub.error(Error::badf()); } else if revents.contains(PollFlags::POLLERR) { rwsub.error(Error::io()); } else if revents.contains(PollFlags::POLLHUP) { rwsub.complete(nbytes, RwEventFlags::HANGUP); } else { rwsub.complete(nbytes, RwEventFlags::empty()); }; } } } else { poll.earliest_clock_deadline() .expect("timed out") .result() .expect("timer deadline is past") .unwrap() } Ok(()) } fn wasi_file_raw_fd(f: &dyn WasiFile) -> Option { let a = f.as_any(); if a.is::() { Some(a.downcast_ref::().unwrap().as_raw_fd()) } else if a.is::() { Some(a.downcast_ref::().unwrap().as_raw_fd()) } else if a.is::() { Some( a.downcast_ref::() .unwrap() .as_raw_fd(), ) } else if a.is::() { Some( a.downcast_ref::() .unwrap() .as_raw_fd(), ) } else { None } } mod poll { use bitflags::bitflags; use std::convert::TryInto; use std::os::unix::io::RawFd; bitflags! { pub struct PollFlags: libc::c_short { const POLLIN = libc::POLLIN; const POLLPRI = libc::POLLPRI; const POLLOUT = libc::POLLOUT; const POLLRDNORM = libc::POLLRDNORM; const POLLWRNORM = libc::POLLWRNORM; const POLLRDBAND = libc::POLLRDBAND; const POLLWRBAND = libc::POLLWRBAND; const POLLERR = libc::POLLERR; const POLLHUP = libc::POLLHUP; const POLLNVAL = libc::POLLNVAL; } } #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] #[repr(C)] pub struct PollFd(libc::pollfd); impl PollFd { pub unsafe fn new(fd: RawFd, events: PollFlags) -> Self { Self(libc::pollfd { fd, events: events.bits(), revents: PollFlags::empty().bits(), }) } pub fn revents(self) -> Option { PollFlags::from_bits(self.0.revents) } } pub fn poll(fds: &mut [PollFd], timeout: libc::c_int) -> Result { let nready = unsafe { libc::poll( fds.as_mut_ptr() as *mut libc::pollfd, fds.len() as libc::nfds_t, timeout, ) }; if nready == -1 { Err(std::io::Error::last_os_error()) } else { // When poll doesn't fail, its return value is a non-negative int, which will // always be convertable to usize, so we can unwrap() here. Ok(nready.try_into().unwrap()) } } }