diff --git a/src/core/buffer.rs b/src/core/buffer.rs index cd80a274..69ec51cf 100644 --- a/src/core/buffer.rs +++ b/src/core/buffer.rs @@ -448,6 +448,60 @@ impl<'a: 'b, 'b> LimitBufsMut<'a, 'b> for [&'a mut [u8]] { } } +/// Calls `w.write` for each element in `bufs`. If `bufs` is empty, `w.write` is called once with +/// an empty slice. If this function encounters an error after successfully writing some bytes, +/// the number of bytes written so far is returned and the error is eaten. We assume the caller +/// can obtain the error by calling this function again afterwards to retrigger it. +/// +/// # Performance notes +/// +/// This function should only be used on `Write` implementations that don't already have an +/// optimized `write_vectored` implementation and for which `write` is a cheap operation, for a +/// example a memory buffer or a middleware with a buffering effect. Using it on something like +/// socket object, for which every `write` call may result in a system call, would defeat the +/// point. +pub fn write_trait_vectored_helper( + w: &mut W, + bufs: &[io::IoSlice], +) -> Result { + // Like std, if bufs is empty then we'll call write with no data + if bufs.is_empty() { + return w.write(&[]); + } + + let mut total = 0; + + for buf in bufs { + if buf.is_empty() { + continue; + } + + let size = loop { + match w.write(buf.as_ref()) { + Ok(size) => break size, + Err(e) if e.kind() == io::ErrorKind::Interrupted => continue, + Err(e) => { + if total > 0 { + // Return what we've written so far rather than returning the error. We + // can surface the error in the next call. + return Ok(total); + } + + return Err(e); + } + } + }; + + total += size; + + if size < buf.len() { + break; + } + } + + Ok(total) +} + pub struct ContiguousBuffer { buf: Vec, start: usize, @@ -540,6 +594,10 @@ impl Write for ContiguousBuffer { Ok(size) } + fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> Result { + write_trait_vectored_helper(self, bufs) + } + fn flush(&mut self) -> Result<(), io::Error> { Ok(()) } @@ -724,6 +782,10 @@ impl + AsMut<[u8]>> Write for RingBuffer { Ok(pos) } + fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> Result { + write_trait_vectored_helper(self, bufs) + } + fn flush(&mut self) -> Result<(), io::Error> { Ok(()) } @@ -1057,41 +1119,56 @@ mod tests { use super::*; use std::io::{Read, Write}; - #[test] - fn test_write_vectored_offset() { - struct MyWriter { - bufs: Vec, - } + struct MyWriter { + bufs: Vec, + cause_error_after: Option<(usize, io::Error)>, + } - impl MyWriter { - fn new() -> Self { - Self { bufs: Vec::new() } + impl MyWriter { + fn new() -> Self { + Self { + bufs: Vec::new(), + cause_error_after: None, } } - impl Write for MyWriter { - fn write(&mut self, buf: &[u8]) -> Result { - self.bufs.push(String::from_utf8(buf.to_vec()).unwrap()); + fn cause_error_after(&mut self, num_slices: usize, e: io::Error) { + self.cause_error_after = Some((num_slices, e)); + } + } + + impl Write for MyWriter { + fn write(&mut self, buf: &[u8]) -> Result { + if let Some((num_slices, e)) = self.cause_error_after.take() { + if num_slices == 0 { + return Err(e); + } - Ok(buf.len()) + self.cause_error_after = Some((num_slices - 1, e)); } - fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> Result { - let mut total = 0; + self.bufs.push(String::from_utf8(buf.to_vec()).unwrap()); - for buf in bufs { - total += buf.len(); - self.bufs.push(String::from_utf8(buf.to_vec()).unwrap()); - } + Ok(buf.len()) + } - Ok(total) - } + fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> Result { + let mut total = 0; - fn flush(&mut self) -> Result<(), io::Error> { - Ok(()) + for buf in bufs { + total += self.write(buf)?; } + + Ok(total) + } + + fn flush(&mut self) -> Result<(), io::Error> { + Ok(()) } + } + #[test] + fn test_write_vectored_offset() { // Empty let mut w = MyWriter::new(); let r = write_vectored_offset(&mut w, &[], 0); @@ -1144,6 +1221,78 @@ mod tests { assert_eq!(w.bufs[0], "anana"); } + #[test] + fn test_write_trait_vectored_helper() { + // Write none to get one empty + let mut w = MyWriter::new(); + assert_eq!(write_trait_vectored_helper(&mut w, &[]).unwrap(), 0); + assert_eq!(w.bufs, vec![""]); + + // Write multiple, skipping empty + let mut w = MyWriter::new(); + assert_eq!( + write_trait_vectored_helper( + &mut w, + &[ + io::IoSlice::new(b"apple"), + io::IoSlice::new(b"banana"), + io::IoSlice::new(b""), + io::IoSlice::new(b"cherry"), + ], + ) + .unwrap(), + 17 + ); + assert_eq!(w.bufs, vec!["apple", "banana", "cherry"]); + + // Error on first slice is returned + let mut w = MyWriter::new(); + w.cause_error_after(0, io::Error::from(io::ErrorKind::Other)); + write_trait_vectored_helper( + &mut w, + &[ + io::IoSlice::new(b"apple"), + io::IoSlice::new(b"banana"), + io::IoSlice::new(b"cherry"), + ], + ) + .unwrap_err(); + + // Error on later slice is eaten, and progress is returned + let mut w = MyWriter::new(); + w.cause_error_after(1, io::Error::from(io::ErrorKind::Other)); + assert_eq!( + write_trait_vectored_helper( + &mut w, + &[ + io::IoSlice::new(b"apple"), + io::IoSlice::new(b"banana"), + io::IoSlice::new(b"cherry"), + ], + ) + .unwrap(), + 5 + ); + assert_eq!(w.bufs, vec!["apple"]); + + // Interrupted error is eaten + let mut w = MyWriter::new(); + w.cause_error_after(1, io::Error::from(io::ErrorKind::Interrupted)); + assert_eq!( + write_trait_vectored_helper( + &mut w, + &[ + io::IoSlice::new(b"apple"), + io::IoSlice::new(b"banana"), + io::IoSlice::new(b"cherry"), + ], + ) + .unwrap(), + 17 + ); + assert_eq!(w.bufs, vec!["apple", "banana", "cherry"]); + } + #[test] fn test_buffer() { let mut b = ContiguousBuffer::new(8);