reuse cap-std-syncs windows scheduler without copypaste

This commit is contained in:
Pat Hickey
2021-05-06 15:45:54 -07:00
parent 3d9b98f1df
commit ff8bdc390b
4 changed files with 212 additions and 410 deletions

View File

@@ -5,112 +5,93 @@ use wasi_common::{
file::WasiFile,
sched::{
subscription::{RwEventFlags, Subscription},
Poll, WasiSched,
Poll,
},
Error, ErrorExt,
};
use poll::{PollFd, PollFlags};
pub struct SyncSched;
impl SyncSched {
pub fn new() -> Self {
SyncSched
pub async fn poll_oneoff<'a>(poll: &mut Poll<'a>) -> Result<(), Error> {
if poll.is_empty() {
return Ok(());
}
}
let mut pollfds = Vec::new();
for s in poll.rw_subscriptions() {
match s {
Subscription::Read(f) => {
let raw_fd = wasi_file_raw_fd(f.file).ok_or(
Error::invalid_argument().context("read subscription fd downcast failed"),
)?;
pollfds.push(unsafe { PollFd::new(raw_fd, PollFlags::POLLIN) });
}
#[async_trait::async_trait(?Send)]
impl WasiSched for SyncSched {
async fn poll_oneoff<'a>(&self, poll: &mut Poll<'a>) -> Result<(), Error> {
if poll.is_empty() {
return Ok(());
Subscription::Write(f) => {
let raw_fd = wasi_file_raw_fd(f.file).ok_or(
Error::invalid_argument().context("write subscription fd downcast failed"),
)?;
pollfds.push(unsafe { PollFd::new(raw_fd, PollFlags::POLLOUT) });
}
Subscription::MonotonicClock { .. } => unreachable!(),
}
let mut pollfds = Vec::new();
for s in poll.rw_subscriptions() {
match s {
Subscription::Read(f) => {
let raw_fd = wasi_file_raw_fd(f.file).ok_or(
Error::invalid_argument().context("read subscription fd downcast failed"),
)?;
pollfds.push(unsafe { PollFd::new(raw_fd, PollFlags::POLLIN) });
}
}
Subscription::Write(f) => {
let raw_fd = wasi_file_raw_fd(f.file).ok_or(
Error::invalid_argument().context("write subscription fd downcast failed"),
)?;
pollfds.push(unsafe { PollFd::new(raw_fd, PollFlags::POLLOUT) });
}
Subscription::MonotonicClock { .. } => unreachable!(),
}
}
let ready = loop {
let poll_timeout = if let Some(t) = poll.earliest_clock_deadline() {
let duration = t.duration_until().unwrap_or(Duration::from_secs(0));
(duration.as_millis() + 1) // XXX try always rounding up?
.try_into()
.map_err(|_| Error::overflow().context("poll timeout"))?
} else {
libc::c_int::max_value()
};
tracing::debug!(
poll_timeout = tracing::field::debug(poll_timeout),
poll_fds = tracing::field::debug(&pollfds),
"poll"
);
match poll::poll(&mut pollfds, poll_timeout) {
Ok(ready) => break ready,
Err(_) => {
let last_err = std::io::Error::last_os_error();
if last_err.raw_os_error().unwrap() == libc::EINTR {
continue;
} else {
return Err(last_err.into());
}
}
}
};
if ready > 0 {
for (rwsub, pollfd) in poll.rw_subscriptions().zip(pollfds.into_iter()) {
if let Some(revents) = pollfd.revents() {
let (nbytes, rwsub) = match rwsub {
Subscription::Read(sub) => {
let ready = sub.file.num_ready_bytes().await?;
(std::cmp::max(ready, 1), sub)
}
Subscription::Write(sub) => (0, sub),
_ => unreachable!(),
};
if revents.contains(PollFlags::POLLNVAL) {
rwsub.error(Error::badf());
} else if revents.contains(PollFlags::POLLERR) {
rwsub.error(Error::io());
} else if revents.contains(PollFlags::POLLHUP) {
rwsub.complete(nbytes, RwEventFlags::HANGUP);
} else {
rwsub.complete(nbytes, RwEventFlags::empty());
};
}
}
let ready = loop {
let poll_timeout = if let Some(t) = poll.earliest_clock_deadline() {
let duration = t.duration_until().unwrap_or(Duration::from_secs(0));
(duration.as_millis() + 1) // XXX try always rounding up?
.try_into()
.map_err(|_| Error::overflow().context("poll timeout"))?
} else {
poll.earliest_clock_deadline()
.expect("timed out")
.result()
.expect("timer deadline is past")
.unwrap()
libc::c_int::max_value()
};
tracing::debug!(
poll_timeout = tracing::field::debug(poll_timeout),
poll_fds = tracing::field::debug(&pollfds),
"poll"
);
match poll::poll(&mut pollfds, poll_timeout) {
Ok(ready) => break ready,
Err(_) => {
let last_err = std::io::Error::last_os_error();
if last_err.raw_os_error().unwrap() == libc::EINTR {
continue;
} else {
return Err(last_err.into());
}
}
}
Ok(())
}
async fn sched_yield(&self) -> Result<(), Error> {
std::thread::yield_now();
Ok(())
}
async fn sleep(&self, duration: Duration) -> Result<(), Error> {
std::thread::sleep(duration);
Ok(())
};
if ready > 0 {
for (rwsub, pollfd) in poll.rw_subscriptions().zip(pollfds.into_iter()) {
if let Some(revents) = pollfd.revents() {
let (nbytes, rwsub) = match rwsub {
Subscription::Read(sub) => {
let ready = sub.file.num_ready_bytes().await?;
(std::cmp::max(ready, 1), sub)
}
Subscription::Write(sub) => (0, sub),
_ => unreachable!(),
};
if revents.contains(PollFlags::POLLNVAL) {
rwsub.error(Error::badf());
} else if revents.contains(PollFlags::POLLERR) {
rwsub.error(Error::io());
} else if revents.contains(PollFlags::POLLHUP) {
rwsub.complete(nbytes, RwEventFlags::HANGUP);
} else {
rwsub.complete(nbytes, RwEventFlags::empty());
};
}
}
} else {
poll.earliest_clock_deadline()
.expect("timed out")
.result()
.expect("timer deadline is past")
.unwrap()
}
Ok(())
}
fn wasi_file_raw_fd(f: &dyn WasiFile) -> Option<RawFd> {