diff --git a/src/core/http1/client.rs b/src/core/http1/client.rs index e5a9411a..1eddbcd9 100644 --- a/src/core/http1/client.rs +++ b/src/core/http1/client.rs @@ -64,10 +64,13 @@ impl<'a, R: AsyncRead, W: AsyncWrite> Request<'a, R, W> { let size_limit = self.hbuf.capacity(); - let req_body = match req.send_header(self.hbuf, method, uri, headers, body_size, websocket) + let req_body = match req + .send_header(self.hbuf, method, uri, headers, body_size, websocket, 0) { - Ok(ret) => ret, - Err(_) => return Err(Error::RequestTooLarge(size_limit)), + protocol::SendHeaderStatus::Complete(req_body) => req_body, + protocol::SendHeaderStatus::Partial(_) | protocol::SendHeaderStatus::Error(_, _) => { + return Err(Error::RequestTooLarge(size_limit)) + } }; if self.bbuf.write_all(initial_body).is_err() { diff --git a/src/core/http1/protocol.rs b/src/core/http1/protocol.rs index 0adab19d..64780e52 100644 --- a/src/core/http1/protocol.rs +++ b/src/core/http1/protocol.rs @@ -19,9 +19,11 @@ #![allow(clippy::collapsible_else_if)] use crate::core::buffer::{cap_bufs, write_vectored_offset, FilledBuf, LimitBufs}; -use arrayvec::ArrayVec; +use crate::core::http1::util::HEADERS_MAX; +use arrayvec::{ArrayString, ArrayVec}; use std::cmp; use std::convert::TryFrom; +use std::fmt::Write as _; use std::io; use std::io::{Read, Write}; use std::mem; @@ -715,6 +717,7 @@ pub enum Error { pub struct ServerProtocol { state: ServerState, ver_min: u8, + header_size: Option, body_size: BodySize, chunk_left: Option, chunk_size: usize, @@ -729,6 +732,7 @@ impl<'buf, 'headers> ServerProtocol { Self { state: ServerState::ReceivingRequest, ver_min: 0, + header_size: None, body_size: BodySize::NoBody, chunk_left: None, chunk_size: 0, @@ -959,7 +963,8 @@ impl<'buf, 'headers> ServerProtocol { reason: &str, headers: &[Header], body_size: BodySize, - ) -> Result<(), Error> { + offset: usize, + ) -> Result, Error> { assert!( self.state == ServerState::AwaitingResponse || self.state == ServerState::ReceivingBody ); @@ -983,16 +988,42 @@ impl<'buf, 'headers> ServerProtocol { let chunked = body_size == BodySize::Unknown && self.ver_min >= 1; - if self.ver_min >= 1 { - writer.write_all(b"HTTP/1.1 ")?; - } else { - writer.write_all(b"HTTP/1.0 ")?; + // Validate user header count upfront + if headers.len() > HEADERS_MAX { + return Err(io::Error::new(io::ErrorKind::InvalidInput, "too many headers").into()); } - write!(writer, "{} {}\r\n", code, reason)?; + // Stack buffers for formatted parts + let mut status_code_str = ArrayString::<16>::new(); + let mut content_length_str = ArrayString::<48>::new(); + let final_crlf = b"\r\n"; + // Status line components as separate slices + let http_version_space = if self.ver_min >= 1 { + b"HTTP/1.1 " + } else { + b"HTTP/1.0 " + }; + let space = b" "; + write!(status_code_str, "{}", code).unwrap(); + let status_code_slice = status_code_str.as_bytes(); + let reason_bytes = reason.as_bytes(); + let crlf = b"\r\n"; + + // Build slice array using ArrayVec - panics on overflow (logic error) + const SLICES_MAX: usize = HEADERS_MAX * 4 + 5 + 4; // user headers + status line + (connection + content-length + transfer-encoding + final) + let mut slices: ArrayVec<&[u8], SLICES_MAX> = ArrayVec::new(); + + // Add status line as 5 separate slices (http_version_space + code + space + reason + crlf) + slices.push(http_version_space); + slices.push(status_code_slice); + slices.push(space); + slices.push(reason_bytes); + slices.push(crlf); + + // Add user headers (4 slices each: name, separator, value, newline) for h in headers.iter() { - // We'll override these headers + // Skip headers we'll override if (h.name.eq_ignore_ascii_case("Connection") && code != 101) || h.name.eq_ignore_ascii_case("Content-Length") || h.name.eq_ignore_ascii_case("Transfer-Encoding") @@ -1000,43 +1031,69 @@ impl<'buf, 'headers> ServerProtocol { continue; } - write!(writer, "{}: ", h.name)?; - writer.write_all(h.value)?; - writer.write_all(b"\r\n")?; + slices.push(h.name.as_bytes()); + slices.push(b": "); + slices.push(h.value); + slices.push(b"\r\n"); } - // Connection header - + // Add Connection header if needed if persistent && self.ver_min == 0 { - writer.write_all(b"Connection: keep-alive\r\n")?; + slices.push(b"Connection: keep-alive\r\n"); } else if !persistent && self.ver_min >= 1 { - writer.write_all(b"Connection: close\r\n")?; + slices.push(b"Connection: close\r\n"); } + // Add chunked connection header if chunked { - writer.write_all(b"Connection: Transfer-Encoding\r\n")?; + slices.push(b"Connection: Transfer-Encoding\r\n"); } - // Content-Length header - + // Add Content-Length header if needed if let BodySize::Known(x) = body_size { - write!(writer, "Content-Length: {}\r\n", x)?; + write!(content_length_str, "Content-Length: {}\r\n", x).unwrap(); + slices.push(content_length_str.as_bytes()); } - // Transfer-Encoding header - + // Add Transfer-Encoding header if needed if chunked { - writer.write_all(b"Transfer-Encoding: chunked\r\n")?; + slices.push(b"Transfer-Encoding: chunked\r\n"); } - writer.write_all(b"\r\n")?; + // Add final CRLF + slices.push(final_crlf); - self.state = ServerState::SendingBody; - self.body_size = body_size; - self.persistent = persistent; - self.chunked = chunked; + // Calculate total length of all slices + let mut total_length = 0; + for slice in &slices { + total_length += slice.len(); + } - Ok(()) + if let Some(header_size) = self.header_size { + if total_length != header_size { + // When resuming after a partial write, the caller must provide the same content + return Err(Error::Io(io::Error::from(io::ErrorKind::InvalidInput))); + } + } else { + self.header_size = Some(total_length); + } + + // Send slices using vectored write with offset + let bytes_written = write_vectored_offset(writer, &slices, offset)?; + + // Check if the entire header has been written + if offset + bytes_written >= total_length { + // Complete - update protocol state and return None + self.state = ServerState::SendingBody; + self.header_size = None; + self.body_size = body_size; + self.persistent = persistent; + self.chunked = chunked; + Ok(None) + } else { + // Partial write - return bytes written for retry + Ok(Some(bytes_written)) + } } pub fn send_body( @@ -1210,8 +1267,16 @@ impl ClientState { } } +pub enum SendHeaderStatus { + Partial(P), + Complete(C), + Error(P, E), +} + pub struct ClientRequest { state: ClientState, + header_size: Option, + offset: usize, } #[allow(clippy::new_without_default)] @@ -1219,9 +1284,12 @@ impl ClientRequest { pub fn new() -> Self { Self { state: ClientState::new(), + header_size: None, + offset: 0, } } + #[allow(clippy::too_many_arguments)] pub fn send_header( mut self, writer: &mut W, @@ -1230,7 +1298,8 @@ impl ClientRequest { headers: &[Header], body_size: BodySize, websocket: bool, - ) -> Result { + offset: usize, + ) -> SendHeaderStatus { let body_size = if websocket { BodySize::NoBody } else { @@ -1239,10 +1308,41 @@ impl ClientRequest { let chunked = body_size == BodySize::Unknown; - write!(writer, "{} {} HTTP/1.1\r\n", method, uri)?; + // Validate user header count upfront + if headers.len() > HEADERS_MAX { + return SendHeaderStatus::Error( + ClientRequest { + state: self.state, + header_size: None, + offset: 0, + }, + io::Error::new(io::ErrorKind::InvalidInput, "too many headers").into(), + ); + } + + // Stack buffers for formatted parts + let mut content_length_str = ArrayString::<48>::new(); + let final_crlf = b"\r\n"; + + // Request line components as separate slices + let method_bytes = method.as_bytes(); + let space = b" "; + let uri_bytes = uri.as_bytes(); + let http_version_crlf = b" HTTP/1.1\r\n"; + + // Build slice array using ArrayVec - panics on overflow (logic error) + const SLICES_MAX: usize = HEADERS_MAX * 4 + 4 + 4; // user headers + request line + (connection + content-length + transfer-encoding + final) + let mut slices: ArrayVec<&[u8], SLICES_MAX> = ArrayVec::new(); + + // Add request line as 4 separate slices (method + space + uri + http_version_crlf) + slices.push(method_bytes); + slices.push(space); + slices.push(uri_bytes); + slices.push(http_version_crlf); + // Add user headers (4 slices each: name, separator, value, newline) for h in headers.iter() { - // We'll override these headers + // Skip headers we'll override if (h.name.eq_ignore_ascii_case("Connection") && !websocket) || h.name.eq_ignore_ascii_case("Content-Length") || h.name.eq_ignore_ascii_case("Transfer-Encoding") @@ -1250,41 +1350,88 @@ impl ClientRequest { continue; } - write!(writer, "{}: ", h.name)?; - writer.write_all(h.value)?; - writer.write_all(b"\r\n")?; + slices.push(h.name.as_bytes()); + slices.push(b": "); + slices.push(h.value); + slices.push(b"\r\n"); } - // Connection header - + // Add chunked connection header if chunked { - writer.write_all(b"Connection: Transfer-Encoding\r\n")?; + slices.push(b"Connection: Transfer-Encoding\r\n"); } - // Content-Length header - + // Add Content-Length header if needed if let BodySize::Known(x) = body_size { if x > 0 || !method.eq_ignore_ascii_case("OPTIONS") && !method.eq_ignore_ascii_case("GET") && !method.eq_ignore_ascii_case("HEAD") { - write!(writer, "Content-Length: {}\r\n", x)?; + write!(content_length_str, "Content-Length: {}\r\n", x).unwrap(); + slices.push(content_length_str.as_bytes()); } } - // Transfer-Encoding header - + // Add Transfer-Encoding header if needed if chunked { - writer.write_all(b"Transfer-Encoding: chunked\r\n")?; + slices.push(b"Transfer-Encoding: chunked\r\n"); + } + + // Add final CRLF + slices.push(final_crlf); + + // Calculate total length of all slices + let mut total_length = 0; + for slice in &slices { + total_length += slice.len(); + } + + if let Some(header_size) = self.header_size { + if total_length != header_size { + // When resuming after a partial write, the caller must provide the same content + return SendHeaderStatus::Error( + ClientRequest { + state: self.state, + header_size: Some(total_length), + offset: self.offset, + }, + io::Error::from(io::ErrorKind::InvalidInput).into(), + ); + } } - writer.write_all(b"\r\n")?; + // Send slices using vectored write with offset + let bytes_written = match write_vectored_offset(writer, &slices, offset) { + Ok(bytes) => bytes, + Err(e) => { + return SendHeaderStatus::Error( + ClientRequest { + state: self.state, + header_size: Some(total_length), + offset: self.offset, + }, + e.into(), + ); + } + }; - self.state.body_size = body_size; - self.state.chunked = chunked; + self.offset = offset + bytes_written; - Ok(ClientRequestBody { state: self.state }) + // Check if the entire header has been written + if self.offset >= total_length { + // Complete - update state and return ClientRequestBody + self.state.body_size = body_size; + self.state.chunked = chunked; + SendHeaderStatus::Complete(ClientRequestBody { state: self.state }) + } else { + // Partial write - return bytes written for retry + SendHeaderStatus::Partial(ClientRequest { + state: self.state, + header_size: Some(total_length), + offset: self.offset, + }) + } } } @@ -1779,6 +1926,7 @@ pub struct ClientFinished { #[cfg(test)] mod tests { use super::*; + use crate::core::buffer::write_trait_vectored_helper; const HEADERS_MAX: usize = 32; @@ -1812,27 +1960,7 @@ mod tests { } fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> Result { - let mut total = 0; - - for buf in bufs { - let size = match self.write(buf.as_ref()) { - Ok(size) => size, - Err(e) => { - if e.kind() == io::ErrorKind::WriteZero && total > 0 { - return Ok(total); - } - return Err(e); - } - }; - - total += size; - - if size < buf.len() { - break; - } - } - - Ok(total) + write_trait_vectored_helper(self, bufs) } fn flush(&mut self) -> Result<(), io::Error> { @@ -1975,8 +2103,10 @@ mod tests { BodySize::Known(resp.body.len()) }; - p.send_response(&mut wbuf, resp.code, &resp.reason, &headers, body_size) + let partial = p + .send_response(&mut wbuf, resp.code, &resp.reason, &headers, body_size, 0) .unwrap(); + assert!(partial.is_none()); let size = wbuf.position() as usize; @@ -2865,6 +2995,7 @@ mod tests { let mut p = ServerProtocol { state: ServerState::ReceivingBody, ver_min: 0, + header_size: None, body_size: test.body_size, chunk_left: test.chunk_left, chunk_size: test.chunk_size, @@ -2935,7 +3066,7 @@ mod tests { body_size: BodySize, ver_min: u8, persistent: bool, - result: Result<(), Error>, + result: Result, Error>, state: ServerState, body_size_after: BodySize, chunked: bool, @@ -2952,11 +3083,11 @@ mod tests { body_size: BodySize::Known(0), ver_min: 1, persistent: false, - result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), + result: Ok(Some(5)), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, - written: "", + written: "HTTP/", }, Test { name: "cant-write-1.0", @@ -2967,11 +3098,11 @@ mod tests { body_size: BodySize::Known(0), ver_min: 0, persistent: false, - result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), + result: Ok(Some(5)), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, - written: "", + written: "HTTP/", }, Test { name: "cant-write-status-line", @@ -2982,7 +3113,7 @@ mod tests { body_size: BodySize::Known(0), ver_min: 0, persistent: false, - result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), + result: Ok(Some(12)), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, @@ -2999,7 +3130,7 @@ mod tests { body_size: BodySize::Known(0), ver_min: 0, persistent: false, - result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), + result: Ok(Some(20)), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, @@ -3016,11 +3147,11 @@ mod tests { body_size: BodySize::Known(0), ver_min: 0, persistent: false, - result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), + result: Ok(Some(24)), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, - written: "HTTP/1.0 200 OK\r\nFoo: ", + written: "HTTP/1.0 200 OK\r\nFoo: Ba", }, Test { name: "cant-write-header-eol", @@ -3033,56 +3164,56 @@ mod tests { body_size: BodySize::Known(0), ver_min: 0, persistent: false, - result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), + result: Ok(Some(26)), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, - written: "HTTP/1.0 200 OK\r\nFoo: Bar", + written: "HTTP/1.0 200 OK\r\nFoo: Bar\r", }, Test { name: "cant-write-keep-alive", - write_space: 26, + write_space: 30, code: 200, reason: "OK", headers: &[], body_size: BodySize::Known(0), ver_min: 0, persistent: true, - result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), + result: Ok(Some(30)), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, - written: "HTTP/1.0 200 OK\r\n", + written: "HTTP/1.0 200 OK\r\nConnection: k", }, Test { name: "cant-write-close", - write_space: 26, + write_space: 30, code: 200, reason: "OK", headers: &[], body_size: BodySize::Known(0), ver_min: 1, persistent: false, - result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), + result: Ok(Some(30)), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, - written: "HTTP/1.1 200 OK\r\n", + written: "HTTP/1.1 200 OK\r\nConnection: c", }, Test { name: "cant-write-transfer-encoding", - write_space: 26, + write_space: 50, code: 200, reason: "OK", headers: &[], body_size: BodySize::Unknown, ver_min: 1, persistent: false, - result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), + result: Ok(Some(50)), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, - written: "HTTP/1.1 200 OK\r\n", + written: "HTTP/1.1 200 OK\r\nConnection: close\r\nConnection: Tr", }, Test { name: "cant-write-content-length", @@ -3093,11 +3224,11 @@ mod tests { body_size: BodySize::Known(0), ver_min: 0, persistent: false, - result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), + result: Ok(Some(26)), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, - written: "HTTP/1.0 200 OK\r\n", + written: "HTTP/1.0 200 OK\r\nContent-L", }, Test { name: "cant-write-te-chunked", @@ -3108,11 +3239,11 @@ mod tests { body_size: BodySize::Unknown, ver_min: 1, persistent: true, - result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), + result: Ok(Some(50)), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, - written: "HTTP/1.1 200 OK\r\nConnection: Transfer-Encoding\r\n", + written: "HTTP/1.1 200 OK\r\nConnection: Transfer-Encoding\r\nTr", }, Test { name: "cant-write-eol", @@ -3123,11 +3254,11 @@ mod tests { body_size: BodySize::Unknown, ver_min: 0, persistent: false, - result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), + result: Ok(Some(18)), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, - written: "HTTP/1.0 200 OK\r\n", + written: "HTTP/1.0 200 OK\r\n\r", }, Test { name: "exclude-headers", @@ -3143,7 +3274,7 @@ mod tests { body_size: BodySize::Unknown, ver_min: 0, persistent: false, - result: Ok(()), + result: Ok(None), state: ServerState::SendingBody, body_size_after: BodySize::Unknown, chunked: false, @@ -3163,7 +3294,7 @@ mod tests { body_size: BodySize::NoBody, ver_min: 0, persistent: false, - result: Ok(()), + result: Ok(None), state: ServerState::SendingBody, body_size_after: BodySize::NoBody, chunked: false, @@ -3178,7 +3309,7 @@ mod tests { body_size: BodySize::NoBody, ver_min: 0, persistent: false, - result: Ok(()), + result: Ok(None), state: ServerState::SendingBody, body_size_after: BodySize::NoBody, chunked: false, @@ -3193,7 +3324,7 @@ mod tests { body_size: BodySize::Known(42), ver_min: 0, persistent: false, - result: Ok(()), + result: Ok(None), state: ServerState::SendingBody, body_size_after: BodySize::Known(42), chunked: false, @@ -3208,7 +3339,7 @@ mod tests { body_size: BodySize::Unknown, ver_min: 0, persistent: false, - result: Ok(()), + result: Ok(None), state: ServerState::SendingBody, body_size_after: BodySize::Unknown, chunked: false, @@ -3223,7 +3354,7 @@ mod tests { body_size: BodySize::NoBody, ver_min: 1, persistent: true, - result: Ok(()), + result: Ok(None), state: ServerState::SendingBody, body_size_after: BodySize::NoBody, chunked: false, @@ -3238,7 +3369,7 @@ mod tests { body_size: BodySize::Known(42), ver_min: 1, persistent: true, - result: Ok(()), + result: Ok(None), state: ServerState::SendingBody, body_size_after: BodySize::Known(42), chunked: false, @@ -3253,7 +3384,7 @@ mod tests { body_size: BodySize::Unknown, ver_min: 1, persistent: true, - result: Ok(()), + result: Ok(None), state: ServerState::SendingBody, body_size_after: BodySize::Unknown, chunked: true, @@ -3268,7 +3399,7 @@ mod tests { body_size: BodySize::NoBody, ver_min: 0, persistent: true, - result: Ok(()), + result: Ok(None), state: ServerState::SendingBody, body_size_after: BodySize::NoBody, chunked: false, @@ -3283,7 +3414,7 @@ mod tests { body_size: BodySize::NoBody, ver_min: 0, persistent: false, - result: Ok(()), + result: Ok(None), state: ServerState::SendingBody, body_size_after: BodySize::NoBody, chunked: false, @@ -3298,7 +3429,7 @@ mod tests { body_size: BodySize::NoBody, ver_min: 1, persistent: true, - result: Ok(()), + result: Ok(None), state: ServerState::SendingBody, body_size_after: BodySize::NoBody, chunked: false, @@ -3313,7 +3444,7 @@ mod tests { body_size: BodySize::NoBody, ver_min: 1, persistent: false, - result: Ok(()), + result: Ok(None), state: ServerState::SendingBody, body_size_after: BodySize::NoBody, chunked: false, @@ -3328,7 +3459,7 @@ mod tests { body_size: BodySize::Known(42), ver_min: 0, persistent: false, - result: Ok(()), + result: Ok(None), state: ServerState::SendingBody, body_size_after: BodySize::NoBody, chunked: false, @@ -3340,6 +3471,7 @@ mod tests { let mut p = ServerProtocol { state: ServerState::AwaitingResponse, ver_min: test.ver_min, + header_size: None, body_size: BodySize::NoBody, chunk_left: None, chunk_size: 0, @@ -3348,16 +3480,25 @@ mod tests { sending_chunk: None, }; - let mut w = MyBuffer::new(test.write_space, false); + let mut w = MyBuffer::new(test.write_space, true); - let r = p.send_response(&mut w, test.code, test.reason, test.headers, test.body_size); + let r = p.send_response( + &mut w, + test.code, + test.reason, + test.headers, + test.body_size, + 0, + ); match r { - Ok(_) => { - match &test.result { - Ok(_) => {} + Ok(partial) => { + let expected = match test.result { + Ok(p) => p, _ => panic!("result mismatch: test={}", test.name), }; + + assert_eq!(partial, expected, "test={}", test.name); } Err(e) => { let expected = match &test.result { @@ -3627,6 +3768,7 @@ mod tests { let mut p = ServerProtocol { state: ServerState::SendingBody, ver_min: 0, + header_size: None, body_size: test.body_size, chunk_left: None, chunk_size: 0, @@ -3902,7 +4044,7 @@ mod tests { headers: &'headers [Header<'buf>], body_size: BodySize, websocket: bool, - result: Result<(), Error>, + result: Result, Error>, body_size_after: BodySize, chunked: bool, written: &'static str, @@ -3917,10 +4059,10 @@ mod tests { headers: &[], body_size: BodySize::Known(0), websocket: false, - result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), + result: Ok(Some(2)), body_size_after: BodySize::Known(0), chunked: false, - written: "", + written: "GE", }, Test { name: "cant-write-request-line", @@ -3930,10 +4072,10 @@ mod tests { headers: &[], body_size: BodySize::Known(0), websocket: false, - result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), + result: Ok(Some(12)), body_size_after: BodySize::Known(0), chunked: false, - written: "GET /foo", + written: "GET /foo HTT", }, Test { name: "cant-write-header-name", @@ -3946,7 +4088,7 @@ mod tests { }], body_size: BodySize::Known(0), websocket: false, - result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), + result: Ok(Some(22)), body_size_after: BodySize::Known(0), chunked: false, written: "GET /foo HTTP/1.1\r\nFoo", @@ -3962,10 +4104,10 @@ mod tests { }], body_size: BodySize::Known(0), websocket: false, - result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), + result: Ok(Some(26)), body_size_after: BodySize::Known(0), chunked: false, - written: "GET /foo HTTP/1.1\r\nFoo: ", + written: "GET /foo HTTP/1.1\r\nFoo: Ba", }, Test { name: "cant-write-header-eol", @@ -3978,36 +4120,36 @@ mod tests { }], body_size: BodySize::Known(0), websocket: false, - result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), + result: Ok(Some(28)), body_size_after: BodySize::Known(0), chunked: false, - written: "GET /foo HTTP/1.1\r\nFoo: Bar", + written: "GET /foo HTTP/1.1\r\nFoo: Bar\r", }, Test { name: "cant-write-transfer-encoding", - write_space: 27, + write_space: 33, method: "POST", uri: "/foo", headers: &[], body_size: BodySize::Unknown, websocket: false, - result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), + result: Ok(Some(33)), body_size_after: BodySize::Unknown, chunked: false, - written: "POST /foo HTTP/1.1\r\n", + written: "POST /foo HTTP/1.1\r\nConnection: T", }, Test { name: "cant-write-content-length", - write_space: 27, + write_space: 29, method: "POST", uri: "/foo", headers: &[], body_size: BodySize::Known(0), websocket: false, - result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), + result: Ok(Some(29)), body_size_after: BodySize::Known(0), chunked: false, - written: "POST /foo HTTP/1.1\r\n", + written: "POST /foo HTTP/1.1\r\nContent-L", }, Test { name: "cant-write-eol", @@ -4017,7 +4159,7 @@ mod tests { headers: &[], body_size: BodySize::Unknown, websocket: false, - result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), + result: Ok(Some(20)), body_size_after: BodySize::Unknown, chunked: false, written: "POST /foo HTTP/1.1\r\n", @@ -4047,7 +4189,7 @@ mod tests { ], body_size: BodySize::Known(0), websocket: false, - result: Ok(()), + result: Ok(None), body_size_after: BodySize::Known(0), chunked: false, written: "POST /foo HTTP/1.1\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n", @@ -4060,7 +4202,7 @@ mod tests { headers: &[], body_size: BodySize::NoBody, websocket: false, - result: Ok(()), + result: Ok(None), body_size_after: BodySize::NoBody, chunked: false, written: "GET /foo HTTP/1.1\r\n\r\n", @@ -4073,7 +4215,7 @@ mod tests { headers: &[], body_size: BodySize::Known(42), websocket: false, - result: Ok(()), + result: Ok(None), body_size_after: BodySize::Known(42), chunked: false, written: "POST /foo HTTP/1.1\r\nContent-Length: 42\r\n\r\n", @@ -4086,7 +4228,7 @@ mod tests { headers: &[], body_size: BodySize::Unknown, websocket: false, - result: Ok(()), + result: Ok(None), body_size_after: BodySize::Unknown, chunked: true, written: "POST /foo HTTP/1.1\r\nConnection: Transfer-Encoding\r\nTransfer-Encoding: chunked\r\n\r\n", @@ -4099,7 +4241,7 @@ mod tests { headers: &[], body_size: BodySize::Known(42), websocket: true, - result: Ok(()), + result: Ok(None), body_size_after: BodySize::NoBody, chunked: false, written: "GET /foo HTTP/1.1\r\n\r\n", @@ -4109,30 +4251,41 @@ mod tests { for test in tests.iter() { let req = ClientRequest::new(); - let mut w = MyBuffer::new(test.write_space, false); + let mut w = MyBuffer::new(test.write_space, true); - let r = req.send_header( + let r = match req.send_header( &mut w, test.method, test.uri, test.headers, test.body_size, test.websocket, - ); + 0, + ) { + SendHeaderStatus::Complete(req_body) => Ok((Some(req_body), None)), + SendHeaderStatus::Partial(req) => Ok((None, Some(req.offset))), + SendHeaderStatus::Error(_, e) => Err(e), + }; match r { - Ok(req_body) => { - match &test.result { - Ok(_) => {} + Ok((ret_req_body, ret_partial)) => { + let expected_partial = match test.result { + Ok(p) => p, _ => panic!("result mismatch: test={}", test.name), }; - assert_eq!( - req_body.state.body_size, test.body_size_after, - "test={}", - test.name - ); - assert_eq!(req_body.state.chunked, test.chunked, "test={}", test.name); + if let Some(expected_partial) = expected_partial { + let partial = ret_partial.unwrap(); + assert_eq!(partial, expected_partial, "test={}", test.name); + } else { + let req_body = ret_req_body.unwrap(); + assert_eq!( + req_body.state.body_size, test.body_size_after, + "test={}", + test.name + ); + assert_eq!(req_body.state.chunked, test.chunked, "test={}", test.name); + } } Err(e) => { let expected = match &test.result { @@ -5290,19 +5443,21 @@ mod tests { let mut out = MyBuffer::new(1024, true); - let req_body = req - .send_header( - &mut out, - "GET", - "/foo", - &[Header { - name: "Host", - value: b"example.com", - }], - BodySize::NoBody, - false, - ) - .unwrap(); + let req_body = match req.send_header( + &mut out, + "GET", + "/foo", + &[Header { + name: "Host", + value: b"example.com", + }], + BodySize::NoBody, + false, + 0, + ) { + SendHeaderStatus::Complete(req_body) => req_body, + _ => panic!("unexpected status"), + }; let expected = "GET /foo HTTP/1.1\r\nHost: example.com\r\n\r\n"; diff --git a/src/core/http1/server.rs b/src/core/http1/server.rs index 23056938..627bd6ca 100644 --- a/src/core/http1/server.rs +++ b/src/core/http1/server.rs @@ -438,7 +438,7 @@ impl<'a, R: AsyncRead, W: AsyncWrite> Response<'a, R, W> { if inner .protocol - .send_response(&mut buf, code, reason, headers, body_size) + .send_response(&mut buf, code, reason, headers, body_size, 0) .is_err() { // enable prepare_header to be called again