diff --git a/crates/wasi-common/src/sched.rs b/crates/wasi-common/src/sched.rs index 3434ef62dc..b9a8062a4b 100644 --- a/crates/wasi-common/src/sched.rs +++ b/crates/wasi-common/src/sched.rs @@ -2,7 +2,7 @@ use crate::clocks::WasiMonotonicClock; use crate::file::WasiFile; use crate::Error; use cap_std::time::Instant; -use std::cell::Ref; +use std::cell::RefMut; pub mod subscription; pub use cap_std::time::Duration; @@ -52,11 +52,11 @@ impl<'a> Poll<'a> { ud, )); } - pub fn subscribe_read(&mut self, file: Ref<'a, dyn WasiFile>, ud: Userdata) { + pub fn subscribe_read(&mut self, file: RefMut<'a, dyn WasiFile>, ud: Userdata) { self.subs .push((Subscription::Read(RwSubscription::new(file)), ud)); } - pub fn subscribe_write(&mut self, file: Ref<'a, dyn WasiFile>, ud: Userdata) { + pub fn subscribe_write(&mut self, file: RefMut<'a, dyn WasiFile>, ud: Userdata) { self.subs .push((Subscription::Write(RwSubscription::new(file)), ud)); } diff --git a/crates/wasi-common/src/sched/subscription.rs b/crates/wasi-common/src/sched/subscription.rs index 799cfc665f..3347d3a826 100644 --- a/crates/wasi-common/src/sched/subscription.rs +++ b/crates/wasi-common/src/sched/subscription.rs @@ -3,7 +3,7 @@ use crate::file::WasiFile; use crate::Error; use bitflags::bitflags; use cap_std::time::{Duration, Instant}; -use std::cell::{Cell, Ref}; +use std::cell::{Cell, RefMut}; bitflags! { pub struct RwEventFlags: u32 { @@ -12,12 +12,12 @@ bitflags! { } pub struct RwSubscription<'a> { - pub file: Ref<'a, dyn WasiFile>, + pub file: RefMut<'a, dyn WasiFile>, status: Cell>>, } impl<'a> RwSubscription<'a> { - pub fn new(file: Ref<'a, dyn WasiFile>) -> Self { + pub fn new(file: RefMut<'a, dyn WasiFile>) -> Self { Self { file, status: Cell::new(None), diff --git a/crates/wasi-common/src/snapshots/preview_0.rs b/crates/wasi-common/src/snapshots/preview_0.rs index 362789bc35..775962afef 100644 --- a/crates/wasi-common/src/snapshots/preview_0.rs +++ b/crates/wasi-common/src/snapshots/preview_0.rs @@ -1,4 +1,4 @@ -use crate::file::{FileCaps, FileEntryExt, TableFileExt}; +use crate::file::{FileCaps, FileEntryExt, FileEntryMutExt, TableFileExt}; use crate::sched::{ subscription::{RwEventFlags, SubscriptionResult}, Poll, @@ -7,6 +7,7 @@ use crate::snapshots::preview_1::types as snapshot1_types; use crate::snapshots::preview_1::wasi_snapshot_preview1::WasiSnapshotPreview1 as Snapshot1; use crate::{Error, ErrorExt, WasiCtx}; use cap_std::time::Duration; +use std::collections::HashSet; use std::convert::{TryFrom, TryInto}; use std::io::{IoSlice, IoSliceMut}; use std::ops::Deref; @@ -778,6 +779,7 @@ impl wasi_unstable::WasiUnstable for WasiCtx { } let table = self.table(); + let mut subscribed_fds = HashSet::new(); let mut poll = Poll::new(); let subs = subs.as_array(nsubscriptions); @@ -816,15 +818,27 @@ impl wasi_unstable::WasiUnstable for WasiCtx { }, types::SubscriptionU::FdRead(readsub) => { let fd = readsub.file_descriptor; + if subscribed_fds.contains(&fd) { + Err(Error::invalid_argument() + .context("Fd can be subscribed to at most once per poll_oneoff"))?; + } else { + subscribed_fds.insert(fd); + } let file = table - .get_file(u32::from(fd))? + .get_file_mut(u32::from(fd))? .get_cap(FileCaps::POLL_READWRITE)?; poll.subscribe_read(file, sub.userdata.into()); } types::SubscriptionU::FdWrite(writesub) => { let fd = writesub.file_descriptor; + if subscribed_fds.contains(&fd) { + Err(Error::invalid_argument() + .context("Fd can be subscribed to at most once per poll_oneoff"))?; + } else { + subscribed_fds.insert(fd); + } let file = table - .get_file(u32::from(fd))? + .get_file_mut(u32::from(fd))? .get_cap(FileCaps::POLL_READWRITE)?; poll.subscribe_write(file, sub.userdata.into()); } diff --git a/crates/wasi-common/src/snapshots/preview_1.rs b/crates/wasi-common/src/snapshots/preview_1.rs index 4d04e6e047..04ba4313d0 100644 --- a/crates/wasi-common/src/snapshots/preview_1.rs +++ b/crates/wasi-common/src/snapshots/preview_1.rs @@ -13,6 +13,7 @@ use crate::{ use anyhow::Context; use cap_std::time::{Duration, SystemClock}; use std::cell::{Ref, RefMut}; +use std::collections::HashSet; use std::convert::{TryFrom, TryInto}; use std::io::{IoSlice, IoSliceMut}; use std::ops::{Deref, DerefMut}; @@ -970,6 +971,7 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { } let table = self.table(); + let mut subscribed_fds = HashSet::new(); let mut poll = Poll::new(); let subs = subs.as_array(nsubscriptions); @@ -1008,15 +1010,27 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { }, types::SubscriptionU::FdRead(readsub) => { let fd = readsub.file_descriptor; + if subscribed_fds.contains(&fd) { + return Err(Error::invalid_argument() + .context("Fd can be subscribed to at most once per poll_oneoff")); + } else { + subscribed_fds.insert(fd); + } let file = table - .get_file(u32::from(fd))? + .get_file_mut(u32::from(fd))? .get_cap(FileCaps::POLL_READWRITE)?; poll.subscribe_read(file, sub.userdata.into()); } types::SubscriptionU::FdWrite(writesub) => { let fd = writesub.file_descriptor; + if subscribed_fds.contains(&fd) { + return Err(Error::invalid_argument() + .context("Fd can be subscribed to at most once per poll_oneoff")); + } else { + subscribed_fds.insert(fd); + } let file = table - .get_file(u32::from(fd))? + .get_file_mut(u32::from(fd))? .get_cap(FileCaps::POLL_READWRITE)?; poll.subscribe_write(file, sub.userdata.into()); }