diff --git a/src/client/body/multipart.rs b/src/client/body/multipart.rs index 2ec7ca14..d43a2195 100644 --- a/src/client/body/multipart.rs +++ b/src/client/body/multipart.rs @@ -12,26 +12,61 @@ use crate::{client::body::PyStream, error::Error, header::HeaderMap}; /// A multipart form for a request. #[pyclass(subclass)] -pub struct Multipart(pub Option); +pub struct Multipart { + pub form: Option, + pub parts: Vec, +} + +/// The data for a part value of a multipart form. +#[derive(FromPyObject)] +pub enum Value { + Text(PyBackedStr), + Bytes(PyBackedBytes), + File(PathBuf), + Stream(PyStream), +} + +/// A part of a multipart form. +#[pyclass(subclass)] +pub struct Part { + pub name: String, + pub value: Option, + pub filename: Option, + pub mime: Option, + pub length: Option, + pub headers: Option, +} + +// ===== impl Multipart ===== #[pymethods] impl Multipart { - /// Creates a new multipart form. + /// Creates a new multipart. #[new] #[pyo3(signature = (*parts))] - pub fn new(parts: &Bound) -> PyResult { - let mut form = multipart::Form::new(); + pub fn new(py: Python, parts: &Bound) -> PyResult { + let mut new_parts = Vec::with_capacity(parts.len()); for part in parts { let part = part.cast::()?; let mut part = part.borrow_mut(); - form = part - .name - .take() - .zip(part.inner.take()) - .map(|(name, inner)| form.part(name, inner)) - .ok_or_else(|| Error::Memory)?; + new_parts.push(part.try_clone(py)?); + } + + Ok(Self { + form: None, + parts: new_parts, + }) + } +} + +impl Multipart { + fn build_form(&mut self, py: Python) -> PyResult { + let mut form = multipart::Form::new(); + for part in &mut self.parts { + let (name, inner) = part.build_form_part(py)?; + form = form.part(name, inner); } - Ok(Multipart(Some(form))) + Ok(form) } } @@ -40,31 +75,120 @@ impl FromPyObject<'_, '_> for Multipart { fn extract(ob: Borrowed) -> PyResult { let multipart = ob.cast::()?; - multipart - .borrow_mut() - .0 - .take() - .map(Some) - .map(Self) - .ok_or_else(|| Error::Memory) - .map_err(Into::into) + let mut multipart = multipart.borrow_mut(); + let form = multipart.build_form(ob.py())?; + + Ok(Multipart { + form: Some(form), + parts: Vec::new(), + }) } } -/// A part of a multipart form. -#[pyclass(subclass)] -pub struct Part { - pub name: Option, - pub inner: Option, +// ===== impl Value ===== + +impl Value { + fn try_clone(&self, py: Python) -> Option { + match self { + Value::Text(text) => { + let text = text.clone_ref(py); + Some(Value::Text(text)) + } + Value::Bytes(bytes) => { + let bytes = bytes.clone_ref(py); + Some(Value::Bytes(bytes)) + } + Value::File(path) => { + let path = path.clone(); + Some(Value::File(path)) + } + Value::Stream(_) => None, + } + } } -/// The data for a part value of a multipart form. -#[derive(FromPyObject)] -pub enum Value { - Text(PyBackedStr), - Bytes(PyBackedBytes), - File(PathBuf), - Stream(PyStream), +// ===== impl Part ===== + +impl Part { + fn with_value(&self, value: Value) -> Part { + Part { + name: self.name.clone(), + value: Some(value), + filename: self.filename.clone(), + mime: self.mime.clone(), + length: self.length, + headers: self.headers.clone(), + } + } + + fn build_inner(value: Value, length: Option) -> Result { + Ok(match value { + Value::Text(text) => multipart::Part::stream(Body::from(Bytes::from_owner(text))), + Value::Bytes(bytes) => multipart::Part::stream(Body::from(Bytes::from_owner(bytes))), + Value::File(path) => pyo3_async_runtimes::tokio::get_runtime() + .block_on(multipart::Part::file(path)) + .map_err(Error::from)?, + Value::Stream(stream) => { + let stream = Body::wrap_stream(stream); + match length { + Some(length) => multipart::Part::stream_with_length(stream, length), + None => multipart::Part::stream(stream), + } + } + }) + } + + fn clone_value_or_take(&mut self, py: Python) -> PyResult { + self.value + .as_ref() + .and_then(|value| value.try_clone(py)) + .or_else(|| self.value.take()) + .ok_or_else(|| Error::Memory.into()) + } + + fn build_form_part(&mut self, py: Python) -> PyResult<(String, multipart::Part)> { + let value = self.clone_value_or_take(py)?; + let name = self.name.clone(); + let filename = self.filename.clone(); + let mime = self.mime.clone(); + let length = self.length; + let headers = self.headers.clone(); + + py.detach(move || { + let mut inner = Self::build_inner(value, length)?; + + if let Some(filename) = filename { + inner = inner.file_name(filename); + } + + if let Some(mime) = mime { + inner = inner.mime_str(&mime).map_err(Error::Library)?; + } + + if let Some(headers) = headers { + inner = inner.headers(headers.0); + } + + Ok((name, inner)) + }) + } + + fn try_clone(&mut self, py: Python) -> PyResult { + if let Some(part) = self + .value + .as_ref() + .and_then(|value| value.try_clone(py)) + .map(|value| self.with_value(value)) + { + return Ok(part); + } + + self.value + .take() + .map(|value| self.with_value(value)) + .ok_or_else(|| Error::Memory) + .map_err(Into::into) + } } #[pymethods] @@ -80,52 +204,20 @@ impl Part { headers = None ))] pub fn new( - py: Python, name: String, value: Value, filename: Option, mime: Option<&str>, length: Option, headers: Option, - ) -> PyResult { - py.detach(|| { - // Create the inner part - let mut inner = match value { - Value::Text(text) => multipart::Part::stream(Body::from(Bytes::from_owner(text))), - Value::Bytes(bytes) => { - multipart::Part::stream(Body::from(Bytes::from_owner(bytes))) - } - Value::File(path) => pyo3_async_runtimes::tokio::get_runtime() - .block_on(multipart::Part::file(path)) - .map_err(Error::from)?, - Value::Stream(stream) => { - let stream = Body::wrap_stream(stream); - match length { - Some(length) => multipart::Part::stream_with_length(stream, length), - None => multipart::Part::stream(stream), - } - } - }; - - // Set the filename and MIME type if provided - if let Some(filename) = filename { - inner = inner.file_name(filename); - } - - // Set the MIME type if provided - if let Some(mime) = mime { - inner = inner.mime_str(mime).map_err(Error::Library)?; - } - - // Set the headers if provided - if let Some(headers) = headers { - inner = inner.headers(headers.0); - } - - Ok(Part { - name: Some(name), - inner: Some(inner), - }) - }) + ) -> Part { + Part { + name, + value: Some(value), + filename, + mime: mime.map(ToOwned::to_owned), + length, + headers, + } } } diff --git a/src/client/req.rs b/src/client/req.rs index cee104f7..d9553a0a 100644 --- a/src/client/req.rs +++ b/src/client/req.rs @@ -396,7 +396,7 @@ where apply_option!( set_if_some, builder, - request.multipart.and_then(|form| form.0), + request.multipart.and_then(|form| form.form), multipart ); apply_option!( diff --git a/tests/multipart_test.py b/tests/multipart_test.py new file mode 100644 index 00000000..b251787a --- /dev/null +++ b/tests/multipart_test.py @@ -0,0 +1,90 @@ +from pathlib import Path + +import pytest +import wreq +from wreq import Multipart, Part + +client = wreq.Client(tls_info=True) + + +def assert_form_value(data, key, expected): + value = data["form"][key] + if isinstance(value, list): + assert expected in value + else: + assert value == expected + + +@pytest.mark.asyncio +@pytest.mark.flaky(reruns=3, reruns_delay=2) +async def test_reuse_multipart_with_clonable_parts(): + form = Multipart( + Part(name="a", value="1"), + Part(name="b", value=b"2"), + Part(name="c", value=Path("./README.md"), filename="README.md", mime="text/plain"), + ) + + for _ in range(3): + resp = await client.post("https://httpbin.io/post", multipart=form) + async with resp: + assert resp.status.is_success() + data = await resp.json() + assert_form_value(data, "a", "1") + assert_form_value(data, "b", "2") + assert "c" in data["files"] + + +@pytest.mark.asyncio +@pytest.mark.flaky(reruns=3, reruns_delay=2) +async def test_stream_part_is_one_shot_when_reusing_multipart(): + def file_stream(path): + with open(path, "rb") as f: + while chunk := f.read(1024): + yield chunk + + form = Multipart( + Part( + name="stream", + value=file_stream("./README.md"), + filename="README.md", + mime="text/plain", + ), + ) + + resp = await client.post("https://httpbin.io/post", multipart=form) + async with resp: + assert resp.status.is_success() + + with pytest.raises(RuntimeError): + resp = await client.post("https://httpbin.io/post", multipart=form) + async with resp: + pass + + +@pytest.mark.asyncio +@pytest.mark.flaky(reruns=3, reruns_delay=2) +async def test_reuse_same_part_without_copy_for_clonable_value(): + part = Part(name="a", value="1") + + form1 = Multipart(part) + form2 = Multipart(part) + + for form in (form1, form2): + resp = await client.post("https://httpbin.io/post", multipart=form) + async with resp: + assert resp.status.is_success() + data = await resp.json() + assert_form_value(data, "a", "1") + + +@pytest.mark.asyncio +@pytest.mark.flaky(reruns=3, reruns_delay=2) +async def test_reuse_same_part_without_copy_fails_for_stream_value(): + def bytes_stream(): + yield b"hello" + + part = Part(name="stream", value=bytes_stream()) + Multipart(part) + + with pytest.raises(RuntimeError): + Multipart(part)