From 718af0bb33c2ec644faf5bbf0ac7eb35769d9f16 Mon Sep 17 00:00:00 2001 From: Justin Karneges Date: Thu, 18 Jun 2026 09:13:24 -0700 Subject: [PATCH] connmgr: add ability to send large response headers --- src/connmgr/connection.rs | 184 +++++++++++------ src/connmgr/server.rs | 6 +- src/connmgr/track.rs | 7 + src/core/http1/error.rs | 2 +- src/core/http1/server.rs | 409 +++++++++++++++++++++----------------- src/core/http1/util.rs | 15 -- 6 files changed, 361 insertions(+), 262 deletions(-) diff --git a/src/connmgr/connection.rs b/src/connmgr/connection.rs index 109b2b192..3283375dd 100644 --- a/src/connmgr/connection.rs +++ b/src/connmgr/connection.rs @@ -1348,10 +1348,10 @@ async fn send_error_response( 400 } - Error::CoreHttp(CoreHttpError::ResponseTooLarge(limit)) => { + Error::CoreHttp(CoreHttpError::ResponseTooManyHeaders(limit)) => { writeln!( &mut body, - "Response header size exceeded limit of {} bytes.", + "The number of response headers exceeded the limit of {}.", limit )?; @@ -1568,14 +1568,22 @@ async fn server_req_read_header_and_body( Ok(Some(result?)) } -struct ReqRespond<'buf, 'st, R: AsyncRead, W: AsyncWrite> { - header: server::ResponseHeader<'buf, 'st, R, W>, +struct ReqRespond<'buf, 'st, 'headers, R: AsyncRead, W: AsyncWrite> { + header: server::ResponseHeader<'buf, 'st, 'headers, R, W>, prepare_body: server::ResponsePrepareBody<'buf, 'st, R, W>, } // Consumes resp if successful #[allow(clippy::too_many_arguments)] -async fn server_req_respond<'buf, 'st, R: AsyncRead, W: AsyncWrite>( +async fn server_req_respond< + 'buf, + 'st, + 'headers, + 'resp: 'headers, + 'tr, + R: AsyncRead, + W: AsyncWrite, +>( id: &str, req: server::Request, resp: &mut Option>, @@ -1585,8 +1593,10 @@ async fn server_req_respond<'buf, 'st, R: AsyncRead, W: AsyncWrite>( body_buf: &mut ContiguousBuffer, packet_buf: &RefCell>, zsender: &AsyncLocalSender, - zreceiver: &TrackedAsyncLocalReceiver<'_, (memorypool::Rc, usize)>, -) -> Result>, Error> { + zreceiver: &TrackedAsyncLocalReceiver<'tr, (memorypool::Rc, usize)>, + headers_scratch: &'headers mut ArrayVec, HEADERS_MAX>, + zresp_scratch: &'resp mut Option>, +) -> Result>, Error> { let msg = { let req_header = req.recv_header(resp.as_mut().unwrap()); @@ -1635,6 +1645,9 @@ async fn server_req_respond<'buf, 'st, R: AsyncRead, W: AsyncWrite>( } }; + // Remove tracking so the message can be retained + let zresp = zresp_scratch.insert(zresp.into_inner()); + let (header, prepare_body) = { let zresp = zresp.get(); @@ -1649,7 +1662,7 @@ async fn server_req_respond<'buf, 'st, R: AsyncRead, W: AsyncWrite>( // Send response header - let mut headers = ArrayVec::::new(); + let headers = headers_scratch; for h in rdata.headers.iter() { if headers.remaining_capacity() == 0 { @@ -1667,7 +1680,7 @@ async fn server_req_respond<'buf, 'st, R: AsyncRead, W: AsyncWrite>( let (header, prepare_body) = match resp_take.prepare_header( rdata.code, rdata.reason, - &headers, + headers, http1::BodySize::Known(rdata.body.len()), resp_state, ) { @@ -1705,43 +1718,50 @@ async fn server_req_handler( let mut resp_state = server::ResponseState::default(); - let r = { - let (req, resp) = server::Request::new(io_split(&stream), buf1, buf2); - let mut resp = Some(resp); + let resp_body = { + let mut zresp_scratch = None; + let mut headers_scratch = ArrayVec::new(); - let ret = match server_req_respond( - id, - req, - &mut resp, - &mut resp_state, - peer_addr, - secure, - body_buf, - packet_buf, - zsender, - zreceiver, - ) - .await - { - Ok(Some(ret)) => ret, - Ok(None) => return Ok(false), // No request - Err(e) => { - // On error, resp is not consumed, so we can use it - send_error_response(resp.take().unwrap(), zreceiver, &e).await?; + let r = { + let (req, resp) = server::Request::new(io_split(&stream), buf1, buf2); + let mut resp = Some(resp); - return Err(e); - } - }; + let ret = match server_req_respond( + id, + req, + &mut resp, + &mut resp_state, + peer_addr, + secure, + body_buf, + packet_buf, + zsender, + zreceiver, + &mut headers_scratch, + &mut zresp_scratch, + ) + .await + { + Ok(Some(ret)) => ret, + Ok(None) => return Ok(false), // No request + Err(e) => { + // On error, resp is not consumed, so we can use it + send_error_response(resp.take().unwrap(), zreceiver, &e).await?; - assert!(resp.is_none()); + return Err(e); + } + }; - ret - }; + assert!(resp.is_none()); - // ABR: discard_while - let header_sent = discard_while(zreceiver, pin!(r.header.send())).await?; + ret + }; - let resp_body = header_sent.start_body(r.prepare_body); + // ABR: discard_while + let header_sent = discard_while(zreceiver, pin!(r.header.send())).await?; + + header_sent.start_body(r.prepare_body) + }; // Send response body @@ -3206,7 +3226,6 @@ where } struct WsReqData { - accept: ArrayString, deflate_config: Option<(websocket::PerMessageDeflateConfig, usize)>, } @@ -3221,6 +3240,7 @@ fn server_stream_process_req_header( instance_id: &str, shared: &StreamSharedData, recv_buf_size: usize, + ws_accept: &mut ArrayString, ) -> Result<(zmq::Message, Option), Error> { let mut websocket = false; let mut ws_version = None; @@ -3299,13 +3319,12 @@ fn server_stream_process_req_header( ); let ws_req_data: Option = if websocket { - let accept = match validate_ws_request(req, ws_version, ws_key) { + *ws_accept = match validate_ws_request(req, ws_version, ws_key) { Ok(s) => s, Err(_) => return Err(Error::InvalidWebSocketRequest), }; Some(WsReqData { - accept, deflate_config: ws_deflate_config, }) } else { @@ -3364,6 +3383,7 @@ async fn server_stream_read_header<'a: 'b, 'b, R: AsyncRead, W: AsyncWrite>( zreceiver: &TrackedAsyncLocalReceiver<'_, (memorypool::Rc, usize)>, shared: &StreamSharedData, recv_buf_size: usize, + ws_accept: &mut ArrayString, ) -> Result< Option<( zmq::Message, @@ -3401,6 +3421,7 @@ async fn server_stream_read_header<'a: 'b, 'b, R: AsyncRead, W: AsyncWrite>( instance_id, shared, recv_buf_size, + ws_accept, ); let body_size = req_ref.body_size; @@ -3415,30 +3436,39 @@ async fn server_stream_read_header<'a: 'b, 'b, R: AsyncRead, W: AsyncWrite>( Ok(Some((msg, body_size, ws_req_data, req_body))) } -struct StreamRespondProceed<'buf, 'st, 'zs, 'tr, R: AsyncRead, W: AsyncWrite, R2> { - header: server::ResponseHeader<'buf, 'st, R, W>, +struct StreamRespondProceed<'buf, 'st, 'headers, 'zs, 'tr, R: AsyncRead, W: AsyncWrite, R2> { + header: server::ResponseHeader<'buf, 'st, 'headers, R, W>, prepare_body: server::ResponsePrepareBody<'buf, 'st, R, W>, zsess_in: ZhttpStreamSessionIn<'zs, 'tr, R2>, ws_config: Option>, } -struct StreamRespondWebSocketRejected<'buf, 'st, R: AsyncRead, W: AsyncWrite> { - header: server::ResponseHeader<'buf, 'st, R, W>, +struct StreamRespondWebSocketRejected<'buf, 'st, 'headers, R: AsyncRead, W: AsyncWrite> { + header: server::ResponseHeader<'buf, 'st, 'headers, R, W>, prepare_body: server::ResponsePrepareBody<'buf, 'st, R, W>, } -enum StreamRespond<'buf, 'st, 'zs, 'tr, R: AsyncRead, W: AsyncWrite, R2> { - Proceed(StreamRespondProceed<'buf, 'st, 'zs, 'tr, R, W, R2>), - WebSocketRejected(StreamRespondWebSocketRejected<'buf, 'st, R, W>), +enum StreamRespond<'buf, 'st, 'headers, 'zs, 'tr, R: AsyncRead, W: AsyncWrite, R2> { + Proceed(StreamRespondProceed<'buf, 'st, 'headers, 'zs, 'tr, R, W, R2>), + WebSocketRejected(StreamRespondWebSocketRejected<'buf, 'st, 'headers, R, W>), +} + +#[derive(Default)] +struct StreamRespondScratch { + zresp: Option>, + ws_ext: ArrayVec, + ws_accept: ArrayString, } // Consumes resp if successful #[allow(clippy::too_many_arguments)] -async fn server_stream_respond<'buf, 'st, 'zs, 'tr, R, W, R1, R2>( +async fn server_stream_respond<'buf, 'st, 'headers, 'resp: 'headers, 'zs, 'tr, R, W, R1, R2>( id: &'zs str, req: server::Request, resp: &mut Option>, resp_state: &'st mut server::ResponseState<'buf, R, W>, + resp_scratch: &'resp mut StreamRespondScratch, + headers_scratch: &'headers mut ArrayVec, HEADERS_MAX>, peer_addr: Option<&SocketAddr>, secure: bool, send_buf_size: usize, @@ -3457,7 +3487,7 @@ async fn server_stream_respond<'buf, 'st, 'zs, 'tr, R, W, R1, R2>( shared: &'zs StreamSharedData, refresh_stream_timeout: &R1, refresh_session_timeout: &'zs R2, -) -> Result>, Error> +) -> Result>, Error> where R: AsyncRead, W: AsyncWrite, @@ -3480,6 +3510,7 @@ where zreceiver, shared, recv_buf_size, + &mut resp_scratch.ws_accept, ) .await?; @@ -3556,20 +3587,38 @@ where // Determine how to respond let rdata = match &zresp.get().ptype { - zhttppacket::ResponsePacket::Data(rdata) => rdata, + zhttppacket::ResponsePacket::Data(_) => { + // Remove tracking so the message can be retained + let zresp = resp_scratch.zresp.insert(zresp.into_inner()); + + // Borrow again + match &zresp.get().ptype { + zhttppacket::ResponsePacket::Data(rdata) => rdata, + _ => unreachable!(), // We confirmed the type above + } + } zhttppacket::ResponsePacket::Error(edata) => { if ws_req_data.is_some() && edata.condition == "rejected" { // Send websocket rejection + // Remove tracking so the message can be retained + let zresp = resp_scratch.zresp.insert(zresp.into_inner()); + + // Borrow again + let edata = match &zresp.get().ptype { + zhttppacket::ResponsePacket::Error(edata) => edata, + _ => unreachable!(), // We confirmed the type above + }; + let rdata = edata.rejected_info.as_ref().unwrap(); if rdata.body.len() > recv_buf_size { return Err(Error::WebSocketRejectionTooLarge(recv_buf_size)); } - let (header, mut prepare_body) = { - let mut headers = ArrayVec::::new(); + let headers = headers_scratch; + let (header, mut prepare_body) = { for h in rdata.headers.iter() { // Don't send these headers if h.name.eq_ignore_ascii_case("Upgrade") @@ -3595,7 +3644,7 @@ where match resp_take.prepare_header( rdata.code, rdata.reason, - &headers, + headers, http1::BodySize::Known(rdata.body.len()), resp_state, ) { @@ -3642,11 +3691,10 @@ where // Send response header let (header, mut prepare_body) = { - let mut ws_ext = ArrayVec::::new(); - let mut headers = ArrayVec::::new(); - let mut body_size = http1::BodySize::Unknown; + let headers = headers_scratch; + for h in rdata.headers.iter() { if ws_req_data.is_some() { // Don't send these headers @@ -3684,8 +3732,10 @@ where body_size = http1::BodySize::Known(rdata.body.len()); } + let ws_ext = &mut resp_scratch.ws_ext; + if let Some(ws_req_data) = &ws_req_data { - let accept_data = &ws_req_data.accept; + let accept_data = &mut resp_scratch.ws_accept; if headers.remaining_capacity() < 4 { return Err(Error::BadMessage); @@ -3707,20 +3757,20 @@ where }); if let Some((config, _)) = &ws_req_data.deflate_config { - if write_ws_ext_header_value(config, &mut ws_ext).is_err() { + if write_ws_ext_header_value(config, ws_ext).is_err() { return Err(Error::Compression); } headers.push(http1::Header { name: "Sec-WebSocket-Extensions", - value: ws_ext.as_ref(), + value: ws_ext, }); } } let mut resp_take = resp.take().unwrap(); - match resp_take.prepare_header(rdata.code, rdata.reason, &headers, body_size, resp_state) { + match resp_take.prepare_header(rdata.code, rdata.reason, headers, body_size, resp_state) { Ok(ret) => ret, Err(e) => { *resp = Some(resp_take); @@ -3790,6 +3840,8 @@ where let zsess_out = ZhttpStreamSessionOut::new(instance_id, id, packet_buf, zsender_stream, shared); let mut resp_state = server::ResponseState::default(); + let mut resp_scratch = StreamRespondScratch::default(); + let mut headers_scratch = ArrayVec::new(); let respond = { let (req, resp) = server::Request::new(io_split(&stream), buf1, buf2); @@ -3800,6 +3852,8 @@ where req, &mut resp, &mut resp_state, + &mut resp_scratch, + &mut headers_scratch, peer_addr, secure, send_buf_size, @@ -8270,8 +8324,8 @@ mod tests { fn server_stream_expand_write_buffer() { let reactor = Reactor::new(100); - let scratch_mem = memorypool::RcMemoryPool::new(1); - let resp_mem = memorypool::RcMemoryPool::new(1); + let scratch_mem = memorypool::RcMemoryPool::new(2); + let resp_mem = memorypool::RcMemoryPool::new(2); let sock = Rc::new(RefCell::new(FakeSock::new())); diff --git a/src/connmgr/server.rs b/src/connmgr/server.rs index 44d4209f5..6aa1624ff 100644 --- a/src/connmgr/server.rs +++ b/src/connmgr/server.rs @@ -3125,9 +3125,9 @@ pub mod tests { #[cfg(debug_assertions)] #[test] fn test_task_sizes() { - // Sizes in debug mode at commit 4c1b0bb177314051405ef5be3cde023e9d1ad635 - const REQ_TASK_SIZE_BASE: usize = 5824; - const STREAM_TASK_SIZE_BASE: usize = 7760; + // Sizes in debug mode at commit TBD + const REQ_TASK_SIZE_BASE: usize = 8000; + const STREAM_TASK_SIZE_BASE: usize = 10064; // Cause tests to fail if sizes grow too much const GROWTH_LIMIT: usize = 1000; diff --git a/src/connmgr/track.rs b/src/connmgr/track.rs index a2e3caf09..1130cf565 100644 --- a/src/connmgr/track.rs +++ b/src/connmgr/track.rs @@ -55,6 +55,13 @@ impl<'a, T> Track<'a, T> { inner: Some(TrackInner { value, active }), } } + + pub fn into_inner(mut self) -> T { + let inner = self.inner.take().unwrap(); + inner.active.set(false); + + inner.value + } } impl<'a, A, B> Track<'a, (A, B)> { diff --git a/src/core/http1/error.rs b/src/core/http1/error.rs index 43f77eb76..e07b9c207 100644 --- a/src/core/http1/error.rs +++ b/src/core/http1/error.rs @@ -22,7 +22,7 @@ pub enum Error { Io(io::Error), Protocol(protocol::Error), RequestTooLarge(usize), - ResponseTooLarge(usize), + ResponseTooManyHeaders(usize), ResponseDuringContinue, FurtherInputNotAllowed, BufferExceeded, diff --git a/src/core/http1/server.rs b/src/core/http1/server.rs index c8287db96..17aae4c02 100644 --- a/src/core/http1/server.rs +++ b/src/core/http1/server.rs @@ -15,7 +15,7 @@ * limitations under the License. */ -use crate::core::buffer::{Buffer, BufferBudget, ContiguousBuffer, VecRingBuffer, BUFFER_BUFS_MAX}; +use crate::core::buffer::{Buffer, BufferBudget, VecRingBuffer, BUFFER_BUFS_MAX}; use crate::core::http1::error::Error; use crate::core::http1::protocol::{self, BodySize, Header, ParseScratch, ParseStatus}; use crate::core::http1::util::*; @@ -407,16 +407,16 @@ impl<'a, R: AsyncRead, W: AsyncWrite> Response<'a, R, W> { } #[allow(clippy::type_complexity)] - pub fn prepare_header<'b>( + pub fn prepare_header<'b, 'resp>( &mut self, code: u16, - reason: &str, - headers: &[Header<'_>], + reason: &'resp str, + headers: &'resp [Header<'resp>], body_size: BodySize, state: &'b mut ResponseState<'a, R, W>, ) -> Result< ( - ResponseHeader<'a, 'b, R, W>, + ResponseHeader<'a, 'b, 'resp, R, W>, ResponsePrepareBody<'a, 'b, R, W>, ), Error, @@ -430,31 +430,12 @@ impl<'a, R: AsyncRead, W: AsyncWrite> Response<'a, R, W> { inner.protocol.skip_recv_request(); } - inner.wbuf.clear(); - let size_limit = inner.wbuf.capacity(); - - let header_size = { - let mut buf = io::Cursor::new(inner.wbuf.write_buf()); - - match inner - .protocol - .send_response(&mut buf, code, reason, headers, body_size, 0) - { - Ok(None) => {} - Ok(Some(_)) | Err(_) => { - // Partial write, or error due to input being too large - - // Enable prepare_header to be called again - inner.wbuf.clear(); - - return Err(Error::ResponseTooLarge(size_limit)); - } - } - - buf.position() as usize - }; + // Validate user header count upfront + if headers.len() > HEADERS_MAX { + return Err(Error::ResponseTooManyHeaders(HEADERS_MAX)); + } - inner.wbuf.write_commit(header_size); + inner.wbuf.clear(); let inner = self.inner.take().unwrap(); @@ -462,18 +443,24 @@ impl<'a, R: AsyncRead, W: AsyncWrite> Response<'a, R, W> { r: inner.r, w: RefCell::new(inner.w), rbuf: inner.rbuf, - wbuf: RefCell::new(LimitedRingBuffer { - inner: inner.wbuf, - limit: header_size, - }), - protocol: inner.protocol, - overflow: RefCell::new(None), + wbuf: RefCell::new(inner.wbuf), + protocol: RefCell::new(inner.protocol), end: Cell::new(false), }); let state = &state.inner; - Ok((ResponseHeader { state }, ResponsePrepareBody { state })) + Ok(( + ResponseHeader { + state, + code, + reason, + headers, + body_size, + offset: RefCell::new(0), + }, + ResponsePrepareBody { state }, + )) } } @@ -481,9 +468,8 @@ struct ResponseStateInner<'a, R: AsyncRead, W: AsyncWrite> { r: ReadHalf<'a, R>, w: RefCell>, rbuf: &'a mut VecRingBuffer, - wbuf: RefCell>, - protocol: protocol::ServerProtocol, - overflow: RefCell>, + wbuf: RefCell<&'a mut VecRingBuffer>, + protocol: RefCell, end: Cell, } @@ -499,39 +485,76 @@ impl Default for ResponseState<'_, R, W> { } } -pub struct ResponseHeader<'a, 'b, R: AsyncRead, W: AsyncWrite> { +pub struct ResponseHeader<'a, 'b, 'resp, R: AsyncRead, W: AsyncWrite> { state: &'b RefCell>>, + code: u16, + reason: &'resp str, + headers: &'resp [Header<'resp>], + body_size: BodySize, + offset: RefCell, } -impl<'a, 'b, R: AsyncRead, W: AsyncWrite> ResponseHeader<'a, 'b, R, W> { - #[allow(clippy::await_holding_refcell_ref)] +impl<'a, 'b, R: AsyncRead, W: AsyncWrite> ResponseHeader<'a, 'b, '_, R, W> { pub async fn send(self) -> Result, Error> { - // ok to hold across await as self.state is only ever immutably borrowed - let state = self.state.borrow(); - let state = state.as_ref().unwrap(); - - while state.wbuf.borrow().limit > 0 { - // ok to hold across await as this is the only place state.w is borrowed - let mut w = state.w.borrow_mut(); - - // TODO: vectored write - let size = w.write_shared(&state.wbuf).await?; + // Use AsyncOperation pattern like client + let result = AsyncOperation::new( + |cx| { + let state = self.state.borrow(); + let state = state.as_ref().unwrap(); + let mut w = state.w.borrow_mut(); + + // Keep trying to write until complete or would block + loop { + if !w.is_writable() { + // Writer not ready, try again later + return None; + } - let mut wbuf = state.wbuf.borrow_mut(); - wbuf.inner.read_commit(size); - wbuf.limit -= size; - } + let offset = *self.offset.borrow(); - let mut overflow = state.overflow.borrow_mut(); + match state.protocol.borrow_mut().send_response( + &mut StdWriteWrapper::new(Pin::new(&mut *w), cx), + self.code, + self.reason, + self.headers, + self.body_size, + offset, + ) { + Ok(None) => { + // Complete - headers sent successfully + return Some(Ok(())); + } + Ok(Some(bytes_written)) => { + // Partial write, update offset and try again immediately + *self.offset.borrow_mut() += bytes_written; + } + Err(e) => { + match e { + protocol::Error::Io(e) if e.kind() == io::ErrorKind::WouldBlock => { + // Writer not ready, try again later + return None; + } + _ => {} + } + + return Some(Err(e.into())); + } + } + } + }, + || { + let state = self.state.borrow(); + if let Some(state) = state.as_ref() { + state.w.borrow_mut().cancel(); + } + }, + ) + .await; - if let Some(overflow_ref) = &mut *overflow { - // Overflow is guaranteed to fit - let mut wbuf = state.wbuf.borrow_mut(); - wbuf.inner.write_all(overflow_ref.read_buf()).unwrap(); - *overflow = None; + match result { + Ok(()) => Ok(ResponseHeaderSent { state: self.state }), + Err(e) => Err(e), } - - Ok(ResponseHeaderSent { state: self.state }) } } @@ -551,34 +574,13 @@ impl ResponsePrepareBody<'_, '_, R, W> { } let wbuf = &mut *state.wbuf.borrow_mut(); - let overflow = &mut *state.overflow.borrow_mut(); // workaround for rust 1.77 #[allow(clippy::unused_io_amount)] - let accepted = if overflow.is_none() { - match wbuf.inner.write(src) { - Ok(size) => size, - Err(e) if e.kind() == io::ErrorKind::WriteZero => 0, - Err(e) => panic!("infallible buffer write failed: {}", e), - } - } else { - 0 - }; - - let (size, overflowed) = if accepted < src.len() { - // Only allow overflowing as much as there are header bytes left - let overflow = overflow.get_or_insert_with(|| ContiguousBuffer::new(wbuf.limit)); - - let remaining = &src[accepted..]; - let overflowed = match overflow.write(remaining) { - Ok(size) => size, - Err(e) if e.kind() == io::ErrorKind::WriteZero => 0, - Err(e) => panic!("infallible buffer write failed: {}", e), - }; - - (accepted + overflowed, overflowed) - } else { - (accepted, 0) + let size = match wbuf.write(src) { + Ok(size) => size, + Err(e) if e.kind() == io::ErrorKind::WriteZero => 0, + Err(e) => panic!("infallible buffer write failed: {}", e), }; assert!(size <= src.len()); @@ -587,7 +589,7 @@ impl ResponsePrepareBody<'_, '_, R, W> { state.end.set(true); } - Ok((size, overflowed)) + Ok((size, 0)) } } @@ -611,8 +613,8 @@ impl<'a, 'b, R: AsyncRead, W: AsyncWrite> ResponseHeaderSent<'a, 'b, R, W> { }), w: RefCell::new(ResponseBodyWrite { stream: state.w.into_inner(), - buf: wbuf.inner, - protocol: state.protocol, + buf: wbuf, + protocol: state.protocol.into_inner(), end: state.end.get(), }), })), @@ -843,6 +845,7 @@ mod tests { struct FakeStream { in_data: Vec, out_data: Vec, + write_limit: Option<(usize, usize)>, } impl FakeStream { @@ -850,6 +853,15 @@ mod tests { Self { in_data: Vec::new(), out_data: Vec::new(), + write_limit: None, + } + } + + fn with_write_limit(size: usize) -> Self { + Self { + in_data: Vec::new(), + out_data: Vec::new(), + write_limit: Some((size, size)), } } } @@ -881,15 +893,30 @@ mod tests { impl AsyncWrite for FakeStream { fn poll_write( mut self: Pin<&mut Self>, - _cx: &mut Context, + _cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - let size = self.out_data.write(buf).unwrap(); + let write_size = if let Some((left, max)) = self.write_limit.take() { + if left == 0 { + // Replenish and yield + self.write_limit = Some((max, max)); + return Poll::Pending; + } - Poll::Ready(Ok(size)) + let size = std::cmp::min(buf.len(), left); + self.write_limit = Some((left - size, max)); + + size + } else { + buf.len() + }; + + self.out_data.extend_from_slice(&buf[..write_size]); + + Poll::Ready(Ok(write_size)) } - fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } @@ -1145,88 +1172,6 @@ mod tests { assert!(fut.as_mut().poll(&mut cx).is_ready()); } - #[test] - fn response_overflow() { - let mut fut = pin!(async { - let mut stream = FakeStream::new(); - stream - .in_data - .write_all("GET /path HTTP/1.1\r\n\r\n".as_bytes()) - .unwrap(); - - let mut body = [0; 100]; - for i in 0..body.len() { - body[i] = b'a' + ((i as u8) % 26); - } - - let attempted_body = str::from_utf8(&body).unwrap(); - let expected_body = &attempted_body[..64]; - - { - let stream = RefCell::new(&mut stream); - - let tmp = Rc::new(TmpBuffer::new(64)); - let mut buf1 = VecRingBuffer::new(64, &tmp); - let mut buf2 = VecRingBuffer::new(64, &tmp); - - let (req, mut resp) = Request::new(io_split(&stream), &mut buf1, &mut buf2); - - let header = req.recv_header(&mut resp); - - let mut scratch = ParseScratch::::new(); - let (req_header, req_body) = header.recv(&mut scratch, None).await.unwrap(); - let req_body = req_body.discard_header(req_header); - drop(req_body); - - let mut state = ResponseState::default(); - - // This will serialize to 39 bytes, leaving 25 bytes left - let (header, mut prepare_body) = - match resp.prepare_header(200, "OK", &[], BodySize::Known(64), &mut state) { - Ok(ret) => ret, - Err(_) => unreachable!(), - }; - - // Only the first 64 bytes will fit - assert_eq!( - prepare_body - .prepare(attempted_body.as_bytes(), true) - .unwrap(), - (64, 39) - ); - - // End is ignored if input doesn't fit, so set end again - assert_eq!(prepare_body.prepare(&[], true).unwrap(), (0, 0)); - - let sent = header.send().await.unwrap(); - - let resp_body = sent.start_body(prepare_body); - - let size = match resp_body.send().await { - SendStatus::Partial(_, size) => size, - _ => unreachable!(), - }; - assert_eq!(size, 25); - - let finished = match resp_body.send().await { - SendStatus::Complete(finished) => finished, - _ => unreachable!(), - }; - - assert!(finished.is_persistent()); - } - - let expected = - "HTTP/1.1 200 OK\r\nContent-Length: 64\r\n\r\n".to_string() + expected_body; - - assert_eq!(str::from_utf8(&stream.out_data).unwrap(), expected); - }); - - let waker = Arc::new(NoopWaker).into(); - let mut cx = Context::from_waker(&waker); - assert!(fut.as_mut().poll(&mut cx).is_ready()); - } - #[test] fn dynamic_header_expansion_enabled() { let mut fut = pin!(async { @@ -1413,4 +1358,112 @@ mod tests { let mut cx = Context::from_waker(&waker); assert!(fut.as_mut().poll(&mut cx).is_ready()); } + + #[test] + fn response_partial_writes() { + let mut fut = pin!(async { + // Test that vectored response sending handles partial writes correctly + let tmp = Rc::new(TmpBuffer::new(1024)); + let mut buf1 = VecRingBuffer::new(256, &tmp); + let mut buf2 = VecRingBuffer::new(256, &tmp); + + // Use a stream that only writes 20 bytes at a time to force partial writes + let mut stream = FakeStream::with_write_limit(20); + + // Add a minimal request to read + stream + .in_data + .write_all("GET /path HTTP/1.1\r\n\r\n".as_bytes()) + .unwrap(); + + { + let stream_cell = RefCell::new(&mut stream); + let (req, mut resp) = Request::new(io_split(&stream_cell), &mut buf1, &mut buf2); + + // Read the request header first + let header = req.recv_header(&mut resp); + let mut scratch = ParseScratch::::new(); + let (req_header, req_body) = header.recv(&mut scratch, None).await.unwrap(); + let _req_body = req_body.discard_header(req_header); + + // Prepare response with multiple large headers to force substantial output + let headers = [ + Header { + name: "Server", + value: b"test-server-with-a-very-long-name-to-increase-header-size", + }, + Header { + name: "Content-Type", + value: b"application/json; charset=utf-8", + }, + Header { + name: "Cache-Control", + value: b"no-cache, no-store, must-revalidate, private, max-age=0", + }, + Header { + name: "X-Custom-Header-1", + value: b"very-long-custom-header-value-to-make-response-larger", + }, + Header { + name: "X-Custom-Header-2", + value: b"another-very-long-custom-header-value-for-testing", + }, + ]; + + let mut state = ResponseState::default(); + let (header, prepare_body) = resp + .prepare_header(200, "OK", &headers, BodySize::Known(6), &mut state) + .unwrap(); + + // This should handle multiple partial writes and eventually succeed + let sent = header.send().await.unwrap(); + + // Continue with body to complete the response + let _resp_body = sent.start_body(prepare_body); + } + + let output = String::from_utf8_lossy(&stream.out_data); + + println!( + "Partial write test output ({} bytes):\n{}", + stream.out_data.len(), + output + ); + + let expected = concat!( + "HTTP/1.1 200 OK\r\n", + "Server: test-server-with-a-very-long-name-to-increase-header-size\r\n", + "Content-Type: application/json; charset=utf-8\r\n", + "Cache-Control: no-cache, no-store, must-revalidate, private, max-age=0\r\n", + "X-Custom-Header-1: very-long-custom-header-value-to-make-response-larger\r\n", + "X-Custom-Header-2: another-very-long-custom-header-value-for-testing\r\n", + "Content-Length: 6\r\n", + "\r\n", + ); + + assert_eq!(output, expected); + }); + + let waker = Arc::new(NoopWaker).into(); + let mut cx = Context::from_waker(&waker); + + // Poll multiple times to handle partial writes + let mut poll_count = 0; + loop { + match fut.as_mut().poll(&mut cx) { + std::task::Poll::Ready(_) => break, + std::task::Poll::Pending => { + poll_count += 1; + if poll_count > 20 { + panic!( + "Too many polls ({}) - operation should have completed", + poll_count + ); + } + } + } + } + + println!("Completed after {} polls", poll_count + 1); + } } diff --git a/src/core/http1/util.rs b/src/core/http1/util.rs index 14bc0d5a3..72aa70867 100644 --- a/src/core/http1/util.rs +++ b/src/core/http1/util.rs @@ -16,7 +16,6 @@ use crate::core::buffer::{Buffer, VecRingBuffer}; use crate::core::io::{AsyncRead, AsyncReadExt}; -use std::cmp; use std::future::Future; use std::io; use std::pin::Pin; @@ -47,20 +46,6 @@ pub async fn recv_nonzero( Ok(()) } -pub struct LimitedRingBuffer<'a> { - pub inner: &'a mut VecRingBuffer, - pub limit: usize, -} - -impl AsRef<[u8]> for LimitedRingBuffer<'_> { - fn as_ref(&self) -> &[u8] { - let buf = Buffer::read_buf(self.inner); - let limit = cmp::min(buf.len(), self.limit); - - &buf[..limit] - } -} - pub struct AsyncOperation where C: FnMut(),