diff --git a/CHANGES/12340.feature.rst b/CHANGES/12340.feature.rst new file mode 100644 index 00000000000..ef97fc49f81 --- /dev/null +++ b/CHANGES/12340.feature.rst @@ -0,0 +1,3 @@ +Added the possibility to provide a progress callback to a ``Payload``, +which its writer methods call to report the total number of bytes written so far. +-- by :user:`mib1185`. diff --git a/aiohttp/payload.py b/aiohttp/payload.py index 9a8dc2f3262..e749cbb5c1e 100644 --- a/aiohttp/payload.py +++ b/aiohttp/payload.py @@ -7,7 +7,7 @@ import sys import warnings from abc import ABC, abstractmethod -from collections.abc import AsyncIterable, AsyncIterator, Iterable +from collections.abc import AsyncIterable, AsyncIterator, Callable, Iterable from itertools import chain from typing import IO, Any, Final, TextIO @@ -151,11 +151,13 @@ def __init__( content_type: None | str | _SENTINEL = sentinel, filename: str | None = None, encoding: str | None = None, + progress: Callable[[int], None] | None = None, **kwargs: Any, ) -> None: self._encoding = encoding self._filename = filename self._headers = CIMultiDict[str]() + self._progress = progress self._value = value if content_type is not sentinel and content_type is not None: assert isinstance(content_type, str) @@ -240,6 +242,19 @@ def set_content_disposition( disptype, quote_fields=quote_fields, _charset=_charset, params=params ) + def set_progress_callback( + self, callback: Callable[[int], None] | None = None + ) -> None: + """ + Set a callback function to be called with the total number of bytes written so far. + + Args: + callback: A callable that takes an integer representing the total number of bytes written so far. + When set to `None`, it will clear any existing progress callback. 
+ + """ + self._progress = callback + @abstractmethod def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: """ @@ -334,6 +349,11 @@ async def close(self) -> None: """ self._close() + def _report_progress(self, total_written_len: int) -> None: + """Call the progress callback if it is set, with the total number of bytes written so far.""" + if self._progress is not None: + self._progress(total_written_len) + class BytesPayload(Payload): _value: bytes @@ -408,10 +428,13 @@ async def write_with_length( is performed efficiently using array slicing. """ + self._report_progress(0) if content_length is not None: await writer.write(self._value[:content_length]) + self._report_progress(min(content_length, len(self._value))) else: await writer.write(self._value) + self._report_progress(len(self._value)) class StringPayload(BytesPayload): @@ -598,6 +621,8 @@ async def write_with_length( total_written_len = 0 remaining_content_len = content_length + self._report_progress(total_written_len) + # Get initial data and available length available_len, chunk = await loop.run_in_executor( None, self._read_and_available_len, remaining_content_len @@ -609,12 +634,13 @@ async def write_with_length( # Write data with or without length constraint if remaining_content_len is None: await writer.write(chunk) + total_written_len += chunk_len else: await writer.write(chunk[:remaining_content_len]) + total_written_len += min(remaining_content_len, chunk_len) remaining_content_len -= chunk_len - total_written_len += chunk_len - + self._report_progress(total_written_len) # Check if we're done writing if self._should_stop_writing( available_len, total_written_len, remaining_content_len @@ -877,8 +903,13 @@ async def write_with_length( """ self._set_or_restore_start_position() loop_count = 0 + total_written_len = 0 remaining_bytes = content_length + + self._report_progress(total_written_len) + while chunk := self._value.read(READ_SIZE): + chunk_len = len(chunk) if loop_count > 0: # Avoid blocking the event loop # 
if they pass a large BytesIO object @@ -887,11 +918,16 @@ async def write_with_length( await asyncio.sleep(0) if remaining_bytes is None: await writer.write(chunk) + total_written_len += chunk_len + self._report_progress(total_written_len) else: await writer.write(chunk[:remaining_bytes]) - remaining_bytes -= len(chunk) + total_written_len += min(remaining_bytes, chunk_len) + self._report_progress(total_written_len) + remaining_bytes -= chunk_len if remaining_bytes <= 0: return + loop_count += 1 async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: @@ -1020,15 +1056,23 @@ async def write_with_length( 4. Does NOT generate cache - that's done by as_bytes() """ + total_written_len = 0 + self._report_progress(total_written_len) + # If we have cached chunks, use them if self._cached_chunks is not None: remaining_bytes = content_length for chunk in self._cached_chunks: + chunk_len = len(chunk) if remaining_bytes is None: await writer.write(chunk) + total_written_len += chunk_len + self._report_progress(total_written_len) elif remaining_bytes > 0: await writer.write(chunk[:remaining_bytes]) - remaining_bytes -= len(chunk) + total_written_len += min(remaining_bytes, chunk_len) + self._report_progress(total_written_len) + remaining_bytes -= chunk_len else: break return @@ -1043,12 +1087,17 @@ async def write_with_length( try: while True: chunk = await anext(self._iter) + chunk_len = len(chunk) if remaining_bytes is None: await writer.write(chunk) + total_written_len += chunk_len + self._report_progress(total_written_len) # If we have a content length limit elif remaining_bytes > 0: await writer.write(chunk[:remaining_bytes]) - remaining_bytes -= len(chunk) + total_written_len += min(remaining_bytes, chunk_len) + self._report_progress(total_written_len) + remaining_bytes -= chunk_len # We still want to exhaust the iterator even # if we have reached the content length limit # since the file handle may not get closed by diff --git 
a/tests/test_payload.py b/tests/test_payload.py index 205a3efdf81..915cd7a5987 100644 --- a/tests/test_payload.py +++ b/tests/test_payload.py @@ -232,6 +232,27 @@ async def test_bytes_payload_write_with_length_truncated() -> None: assert len(writer.get_written_bytes()) == 5 +async def test_bytes_payload_write_progress_callback() -> None: + """Test BytesPayload writing with progress callback.""" + progress_callback = unittest.mock.Mock() + p = payload.BytesPayload(b"0123456789") + p.set_progress_callback(progress_callback) + writer = MockStreamWriter() + + await p.write_with_length(writer, 5) + assert progress_callback.call_args_list == [ + unittest.mock.call(0), + unittest.mock.call(5), + ] + progress_callback.call_args_list.clear() + + await p.write_with_length(writer, None) + assert progress_callback.call_args_list == [ + unittest.mock.call(0), + unittest.mock.call(10), + ] + + async def test_iobase_payload_write_with_length_no_limit() -> None: """Test IOBasePayload writing with no content length limit.""" data = b"0123456789" @@ -265,6 +286,68 @@ async def test_iobase_payload_write_with_length_truncated() -> None: assert len(writer.get_written_bytes()) == 5 +async def test_iobase_payload_write_progress_callback() -> None: + """Test IOBasePayload writing with progress callback.""" + progress_callback = unittest.mock.Mock() + p = payload.IOBasePayload(io.BytesIO(b"0123456789")) + p.set_progress_callback(progress_callback) + writer = MockStreamWriter() + + await p.write_with_length(writer, 5) + assert progress_callback.call_args_list == [ + unittest.mock.call(0), + unittest.mock.call(5), + ] + progress_callback.call_args_list.clear() + + await p.write_with_length(writer, None) + assert progress_callback.call_args_list == [ + unittest.mock.call(0), + unittest.mock.call(10), + ] + + +@pytest.mark.parametrize( + ("content_length", "expected_calls"), + [ + ( + 6, + [ + unittest.mock.call(0), + unittest.mock.call(4), + unittest.mock.call(6), + ], + ), + ( + 10, + [ + 
unittest.mock.call(0), + unittest.mock.call(4), + unittest.mock.call(8), + unittest.mock.call(10), + ], + ), + ], +) +async def test_iobase_payload_write_chunked_progress_callback( + content_length: int, expected_calls: list[unittest.mock._Call] +) -> None: + """Test IOBasePayload writing in chunks with progress callback.""" + # Mock the file-like object to track read calls + mock_file = unittest.mock.Mock(spec=io.BytesIO) + mock_file.tell.return_value = 0 + mock_file.fileno.side_effect = AttributeError # Make size return None + mock_file.read.side_effect = [b"0123", b"4567", b"89"] + + progress_callback = unittest.mock.Mock() + p = payload.IOBasePayload(mock_file) + writer = MockStreamWriter() + p.set_progress_callback(progress_callback) + + await p.write_with_length(writer, content_length) + assert progress_callback.call_args_list == expected_calls + + async def test_bytesio_payload_write_with_length_no_limit() -> None: """Test BytesIOPayload writing with no content length limit.""" data = b"0123456789" @@ -350,6 +433,27 @@ async def test_bytesio_payload_remaining_bytes_exhausted() -> None: assert written == data[:8000] +async def test_bytesio_payload_write_progress_callback() -> None: + """Test BytesIOPayload writing with progress callback.""" + progress_callback = unittest.mock.Mock() + p = payload.BytesIOPayload(io.BytesIO(b"0123456789abcdef" * 1000)) + p.set_progress_callback(progress_callback) + writer = MockStreamWriter() + + await p.write_with_length(writer, 5) + assert progress_callback.call_args_list == [ + unittest.mock.call(0), + unittest.mock.call(5), + ] + progress_callback.call_args_list.clear() + + await p.write_with_length(writer, None) + assert progress_callback.call_args_list == [ + unittest.mock.call(0), + unittest.mock.call(16000), + ] + + async def test_iobase_payload_exact_chunk_size_limit() -> None: """Test IOBasePayload with content length matching exactly one read chunk.""" chunk_size = 2**16 # 65536 bytes (READ_SIZE) @@ -576,6 +680,47 @@ 
async def gen() -> AsyncIterator[bytes]: assert len(writer.get_written_bytes()) == 4 +@pytest.mark.parametrize( + ("content_length", "expected_calls"), + [ + ( + 6, + [ + unittest.mock.call(0), + unittest.mock.call(4), + unittest.mock.call(6), + ], + ), + ( + None, + [ + unittest.mock.call(0), + unittest.mock.call(4), + unittest.mock.call(8), + unittest.mock.call(10), + ], + ), + ], +) +async def test_async_iterable_payload_write_chunked_progress_callback( + content_length: int | None, expected_calls: list[unittest.mock._Call] +) -> None: + """Test AsyncIterablePayload writing in chunks with progress callback.""" + + async def gen() -> AsyncIterator[bytes]: + yield b"0123" + yield b"4567" + yield b"89" + + progress_callback = unittest.mock.Mock() + p = payload.AsyncIterablePayload(gen()) + p.set_progress_callback(progress_callback) + writer = MockStreamWriter() + + await p.write_with_length(writer, content_length) + assert progress_callback.call_args_list == expected_calls + + async def test_bytes_payload_backwards_compatibility() -> None: """Test BytesPayload.write() backwards compatibility delegates to write_with_length().""" p = payload.BytesPayload(b"1234567890")