diff --git a/crates/wasi-http/src/http_impl.rs b/crates/wasi-http/src/http_impl.rs index 69c94d8489..5b0e3862d9 100644 --- a/crates/wasi-http/src/http_impl.rs +++ b/crates/wasi-http/src/http_impl.rs @@ -1,5 +1,5 @@ use crate::r#struct::ActiveResponse; -pub use crate::r#struct::WasiHttp; +use crate::r#struct::{Stream, WasiHttp}; use crate::types::{RequestOptions, Scheme}; #[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))] use anyhow::anyhow; @@ -183,8 +183,10 @@ impl WasiHttp { let body = Full::::new( self.streams .get(&request.body) - .unwrap_or(&Bytes::new()) - .clone(), + .unwrap_or(&Stream::default()) + .data + .clone() + .freeze(), ); let t = timeout(first_bytes_timeout, sender.send_request(call.body(body)?)).await?; let mut res = t?; @@ -222,7 +224,7 @@ impl WasiHttp { } response.body = self.streams_id_base; self.streams_id_base = self.streams_id_base + 1; - self.streams.insert(response.body, buf.freeze()); + self.streams.insert(response.body, buf.freeze().into()); self.responses.insert(response_id, response); Ok(response_id) } diff --git a/crates/wasi-http/src/streams_impl.rs b/crates/wasi-http/src/streams_impl.rs index 6b88fae762..7bc0405c14 100644 --- a/crates/wasi-http/src/streams_impl.rs +++ b/crates/wasi-http/src/streams_impl.rs @@ -2,7 +2,6 @@ use crate::poll::Pollable; use crate::streams::{InputStream, OutputStream, StreamError}; use crate::WasiHttp; use anyhow::{anyhow, bail}; -use bytes::BufMut; use std::vec::Vec; impl crate::streams::Host for WasiHttp { @@ -11,10 +10,14 @@ impl crate::streams::Host for WasiHttp { stream: InputStream, len: u64, ) -> wasmtime::Result, bool), StreamError>> { - let s = self + let st = self .streams .get_mut(&stream) .ok_or_else(|| anyhow!("stream not found: {stream}"))?; + if st.closed { + bail!("stream is dropped!"); + } + let s = &mut st.data; if len == 0 { Ok(Ok((bytes::Bytes::new().to_vec(), s.len() > 0))) } else if s.len() > len.try_into()? { @@ -31,10 +34,14 @@ impl crate::streams::Host for WasiHttp { stream: InputStream, len: u64, ) -> wasmtime::Result> { - let s = self + let st = self .streams .get_mut(&stream) .ok_or_else(|| anyhow!("stream not found: {stream}"))?; + if st.closed { + bail!("stream is dropped!"); + } + let s = &mut st.data; if len == 0 { Ok(Ok((0, s.len() > 0))) } else if s.len() > len.try_into()? { @@ -52,7 +59,11 @@ impl crate::streams::Host for WasiHttp { } fn drop_input_stream(&mut self, stream: InputStream) -> wasmtime::Result<()> { - self.streams.remove(&stream); + let st = self + .streams + .get_mut(&stream) + .ok_or_else(|| anyhow!("stream not found: {stream}"))?; + st.closed = true; Ok(()) } @@ -61,18 +72,13 @@ impl crate::streams::Host for WasiHttp { this: OutputStream, buf: Vec, ) -> wasmtime::Result> { - match self.streams.get(&this) { - Some(data) => { - let mut new = bytes::BytesMut::with_capacity(data.len() + buf.len()); - new.put(data.clone()); - new.put(bytes::Bytes::from(buf.clone())); - self.streams.insert(this, new.freeze()); - } - None => { - self.streams.insert(this, bytes::Bytes::from(buf.clone())); - } + let len = buf.len(); + let st = self.streams.entry(this).or_default(); + if st.closed { + bail!("cannot write to closed stream"); } - Ok(Ok(buf.len().try_into()?)) + st.data.extend_from_slice(buf.as_slice()); + Ok(Ok(len.try_into()?)) } fn write_zeroes( @@ -111,7 +117,11 @@ impl crate::streams::Host for WasiHttp { } fn drop_output_stream(&mut self, stream: OutputStream) -> wasmtime::Result<()> { - self.streams.remove(&stream); + let st = self + .streams + .get_mut(&stream) + .ok_or_else(|| anyhow!("stream not found: {stream}"))?; + st.closed = true; Ok(()) } } diff --git a/crates/wasi-http/src/struct.rs b/crates/wasi-http/src/struct.rs index 0cb1245048..574be6c8e5 100644 --- a/crates/wasi-http/src/struct.rs +++ b/crates/wasi-http/src/struct.rs @@ -1,7 +1,13 @@ use crate::types::{Method, Scheme}; -use bytes::Bytes; +use bytes::{BufMut, Bytes, BytesMut}; use std::collections::HashMap; +#[derive(Clone, Default)] +pub struct Stream { + pub closed: bool, + pub data: BytesMut, +} + #[derive(Clone)] pub struct WasiHttp { pub request_id_base: u32, @@ -11,7 +17,7 @@ pub struct WasiHttp { pub requests: HashMap, pub responses: HashMap, pub fields: HashMap>>, - pub streams: HashMap, + pub streams: HashMap, } #[derive(Clone)] @@ -66,6 +72,23 @@ impl ActiveResponse { } } +impl Stream { + pub fn new() -> Self { + Self::default() + } +} + +impl From for Stream { + fn from(bytes: Bytes) -> Self { + let mut buf = BytesMut::with_capacity(bytes.len()); + buf.put(bytes); + Self { + closed: false, + data: buf, + } + } +} + impl WasiHttp { pub fn new() -> Self { Self { diff --git a/crates/wasi-http/src/types_impl.rs b/crates/wasi-http/src/types_impl.rs index 721dc4cb9e..e9079de7de 100644 --- a/crates/wasi-http/src/types_impl.rs +++ b/crates/wasi-http/src/types_impl.rs @@ -1,5 +1,5 @@ use crate::poll::Pollable; -use crate::r#struct::ActiveRequest; +use crate::r#struct::{ActiveRequest, Stream}; use crate::types::{ Error, Fields, FutureIncomingResponse, Headers, IncomingRequest, IncomingResponse, IncomingStream, Method, OutgoingRequest, OutgoingResponse, OutgoingStream, ResponseOutparam, @@ -7,7 +7,7 @@ use crate::types::{ }; use crate::WasiHttp; use anyhow::{anyhow, bail}; -use std::collections::HashMap; +use std::collections::{hash_map::Entry, HashMap}; impl crate::types::Host for WasiHttp { fn drop_fields(&mut self, fields: Fields) -> wasmtime::Result<()> { @@ -123,7 +123,10 @@ impl crate::types::Host for WasiHttp { bail!("unimplemented: drop_incoming_request") } fn drop_outgoing_request(&mut self, request: OutgoingRequest) -> wasmtime::Result<()> { - self.requests.remove(&request); + if let Entry::Occupied(e) = self.requests.entry(request) { + let r = e.remove(); + self.streams.remove(&r.body); + } Ok(()) } fn incoming_request_method(&mut self, _request: IncomingRequest) -> wasmtime::Result { @@ -192,6 +195,7 @@ impl crate::types::Host for WasiHttp { if req.body == 0 { req.body = self.streams_id_base; self.streams_id_base = self.streams_id_base + 1; + self.streams.insert(req.body, Stream::default()); } Ok(Ok(req.body)) } @@ -206,7 +210,10 @@ impl crate::types::Host for WasiHttp { bail!("unimplemented: set_response_outparam") } fn drop_incoming_response(&mut self, response: IncomingResponse) -> wasmtime::Result<()> { - self.responses.remove(&response); + if let Entry::Occupied(e) = self.responses.entry(response) { + let r = e.remove(); + self.streams.remove(&r.body); + } Ok(()) } fn drop_outgoing_response(&mut self, _response: OutgoingResponse) -> wasmtime::Result<()> {