Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES/12340.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Added the ability to provide a callback to a ``Payload``,
which its writer methods use to report the number of bytes already written.
-- by :user:`mib1185`.
61 changes: 55 additions & 6 deletions aiohttp/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, AsyncIterator, Iterable
from collections.abc import AsyncIterable, AsyncIterator, Callable, Iterable
from itertools import chain
from typing import IO, Any, Final, TextIO

Expand Down Expand Up @@ -151,11 +151,13 @@ def __init__(
content_type: None | str | _SENTINEL = sentinel,
filename: str | None = None,
encoding: str | None = None,
progress: Callable[[int], None] | None = None,
**kwargs: Any,
) -> None:
self._encoding = encoding
self._filename = filename
self._headers = CIMultiDict[str]()
self._progress = progress
self._value = value
if content_type is not sentinel and content_type is not None:
assert isinstance(content_type, str)
Expand Down Expand Up @@ -240,6 +242,19 @@ def set_content_disposition(
disptype, quote_fields=quote_fields, _charset=_charset, params=params
)

def set_progress_callback(
self, callback: Callable[[int], None] | None = None
) -> None:
"""
Set a callback function to be called with the total number of bytes written so far.

Args:
callback: A callable that takes an integer representing the total number of bytes written so far.
When set to `None`, it will clear any existing progress callback.

"""
self._progress = callback

@abstractmethod
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
"""
Expand Down Expand Up @@ -334,6 +349,11 @@ async def close(self) -> None:
"""
self._close()

def _report_progress(self, total_written_len: int) -> None:
"""Call the progress callback if it is set, with the total number of bytes written so far."""
if self._progress:
self._progress(total_written_len)


class BytesPayload(Payload):
_value: bytes
Expand Down Expand Up @@ -408,10 +428,13 @@ async def write_with_length(
is performed efficiently using array slicing.

"""
self._report_progress(0)
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if it makes sense to add progress-callback support to payloads that don't support chunked writing.

Copy link
Copy Markdown
Member

@Dreamsorcerer Dreamsorcerer Apr 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also misses the .write() method. I feel like such a thing should probably be in the writer instead..?

I'm also wondering if you need a callback, or whether you can already achieve this using writer.output_size?

And if we do decide to go with a callback, should we put it in writer.drain() so it's updated once the bytes have finished sending, instead of before?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm also wondering if you need a callback, or whether you can already achieve this using writer.output_size?

In particular here, I'm trying to think of performance. If we're uploading 1 GB and this callback gets called every 1 KB, that's going to be a lot of CPU churn. In that case, you'd be better just running a task that reads the attribute every second or so.

Copy link
Copy Markdown
Author

@mib1185 mib1185 Apr 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or whether you can already achieve this using writer.output_size?

I'm using the described approach from Sending Multipart Requests in the synology dsm lib here and TBH I've no clue how to reach the mentioned writer.output_size property. Any hint how to reach this property would highly be appreciated and the need of this PR would also be gone 😬

Copy link
Copy Markdown
Member

@Dreamsorcerer Dreamsorcerer Apr 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, looking through, it looks like the only semi-reasonable way to reach it currently would be to use a middleware and extract request._writer from there. Definitely looks like we should do something to expose it more easily..

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also probably possible to use the tracing API and track how much data is sent through on_request_chunk_sent: https://docs.aiohttp.org/en/stable/client_advanced.html#aiohttp-client-tracing

Though still not convinced that's the best option for you.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll get @bdraco's thoughts on it too and see what plan we can come up with.

if content_length is not None:
await writer.write(self._value[:content_length])
self._report_progress(content_length)
else:
await writer.write(self._value)
self._report_progress(len(self._value))


class StringPayload(BytesPayload):
Expand Down Expand Up @@ -598,6 +621,8 @@ async def write_with_length(
total_written_len = 0
remaining_content_len = content_length

self._report_progress(total_written_len)

# Get initial data and available length
available_len, chunk = await loop.run_in_executor(
None, self._read_and_available_len, remaining_content_len
Expand All @@ -609,12 +634,13 @@ async def write_with_length(
# Write data with or without length constraint
if remaining_content_len is None:
await writer.write(chunk)
total_written_len += chunk_len
else:
await writer.write(chunk[:remaining_content_len])
total_written_len += min(remaining_content_len, chunk_len)
remaining_content_len -= chunk_len

total_written_len += chunk_len

self._report_progress(total_written_len)
# Check if we're done writing
if self._should_stop_writing(
available_len, total_written_len, remaining_content_len
Expand Down Expand Up @@ -877,8 +903,13 @@ async def write_with_length(
"""
self._set_or_restore_start_position()
loop_count = 0
total_written_len = 0
remaining_bytes = content_length

self._report_progress(total_written_len)

while chunk := self._value.read(READ_SIZE):
chunk_len = len(chunk)
if loop_count > 0:
# Avoid blocking the event loop
# if they pass a large BytesIO object
Expand All @@ -887,11 +918,16 @@ async def write_with_length(
await asyncio.sleep(0)
if remaining_bytes is None:
await writer.write(chunk)
total_written_len += chunk_len
self._report_progress(total_written_len)
else:
await writer.write(chunk[:remaining_bytes])
remaining_bytes -= len(chunk)
total_written_len += min(remaining_bytes, chunk_len)
self._report_progress(total_written_len)
remaining_bytes -= chunk_len
if remaining_bytes <= 0:
return

loop_count += 1

async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:
Expand Down Expand Up @@ -1020,15 +1056,23 @@ async def write_with_length(
4. Does NOT generate cache - that's done by as_bytes()

"""
total_written_len = 0
self._report_progress(total_written_len)

# If we have cached chunks, use them
if self._cached_chunks is not None:
remaining_bytes = content_length
for chunk in self._cached_chunks:
chunk_len = len(chunk)
if remaining_bytes is None:
await writer.write(chunk)
total_written_len += chunk_len
self._report_progress(total_written_len)
elif remaining_bytes > 0:
await writer.write(chunk[:remaining_bytes])
remaining_bytes -= len(chunk)
total_written_len += min(remaining_bytes, chunk_len)
self._report_progress(total_written_len)
remaining_bytes -= chunk_len
else:
break
return
Expand All @@ -1043,12 +1087,17 @@ async def write_with_length(
try:
while True:
chunk = await anext(self._iter)
chunk_len = len(chunk)
if remaining_bytes is None:
await writer.write(chunk)
total_written_len += chunk_len
self._report_progress(total_written_len)
# If we have a content length limit
elif remaining_bytes > 0:
await writer.write(chunk[:remaining_bytes])
remaining_bytes -= len(chunk)
total_written_len += min(remaining_bytes, chunk_len)
self._report_progress(total_written_len)
remaining_bytes -= chunk_len
# We still want to exhaust the iterator even
# if we have reached the content length limit
# since the file handle may not get closed by
Expand Down
145 changes: 145 additions & 0 deletions tests/test_payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,27 @@ async def test_bytes_payload_write_with_length_truncated() -> None:
assert len(writer.get_written_bytes()) == 5


async def test_bytes_payload_write_progress_callback() -> None:
    """BytesPayload reports 0 up front and the final total after writing."""
    recorded = unittest.mock.Mock()
    p = payload.BytesPayload(b"0123456789")
    p.set_progress_callback(recorded)
    writer = MockStreamWriter()

    # Truncated write: only the first 5 bytes are sent.
    await p.write_with_length(writer, 5)
    expected = [unittest.mock.call(0), unittest.mock.call(5)]
    assert recorded.call_args_list == expected
    recorded.call_args_list.clear()

    # Unbounded write: the whole payload is sent.
    await p.write_with_length(writer, None)
    expected = [unittest.mock.call(0), unittest.mock.call(10)]
    assert recorded.call_args_list == expected


async def test_iobase_payload_write_with_length_no_limit() -> None:
"""Test IOBasePayload writing with no content length limit."""
data = b"0123456789"
Expand Down Expand Up @@ -265,6 +286,68 @@ async def test_iobase_payload_write_with_length_truncated() -> None:
assert len(writer.get_written_bytes()) == 5


async def test_iobase_payload_write_progress_callback() -> None:
    """IOBasePayload reports 0 up front and the final total after writing."""
    recorded = unittest.mock.Mock()
    p = payload.IOBasePayload(io.BytesIO(b"0123456789"))
    p.set_progress_callback(recorded)
    writer = MockStreamWriter()

    # Truncated write: only the first 5 bytes are sent.
    await p.write_with_length(writer, 5)
    expected = [unittest.mock.call(0), unittest.mock.call(5)]
    assert recorded.call_args_list == expected
    recorded.call_args_list.clear()

    # Unbounded write: position is restored and all 10 bytes are sent.
    await p.write_with_length(writer, None)
    expected = [unittest.mock.call(0), unittest.mock.call(10)]
    assert recorded.call_args_list == expected


@pytest.mark.parametrize(
    ("content_length", "expected_calls"),
    [
        (
            6,
            [
                unittest.mock.call(0),
                unittest.mock.call(4),
                unittest.mock.call(6),
            ],
        ),
        (
            10,
            [
                unittest.mock.call(0),
                unittest.mock.call(4),
                unittest.mock.call(8),
                unittest.mock.call(10),
            ],
        ),
    ],
)
async def test_iobase_payload_write_chunked_progress_callback(
    content_length: int, expected_calls: list[unittest.mock._Call]
) -> None:
    """IOBasePayload reports a running total after each chunk it writes."""
    # Fake file object so reads come back in three fixed chunks.
    fake_file = unittest.mock.Mock(spec=io.BytesIO)
    fake_file.tell.return_value = 0
    # No fileno() -> the payload cannot determine a size up front.
    fake_file.fileno.side_effect = AttributeError
    fake_file.read.side_effect = [b"0123", b"4567", b"89"]

    recorded = unittest.mock.Mock()
    p = payload.IOBasePayload(fake_file)
    writer = MockStreamWriter()
    p.set_progress_callback(recorded)

    await p.write_with_length(writer, content_length)
    assert recorded.call_args_list == expected_calls


async def test_bytesio_payload_write_with_length_no_limit() -> None:
"""Test BytesIOPayload writing with no content length limit."""
data = b"0123456789"
Expand Down Expand Up @@ -350,6 +433,27 @@ async def test_bytesio_payload_remaining_bytes_exhausted() -> None:
assert written == data[:8000]


async def test_bytesio_payload_write_progress_callback() -> None:
    """BytesIOPayload reports 0 up front and the final total after writing."""
    recorded = unittest.mock.Mock()
    p = payload.BytesIOPayload(io.BytesIO(b"0123456789abcdef" * 1000))
    p.set_progress_callback(recorded)
    writer = MockStreamWriter()

    # Truncated write: only the first 5 bytes are sent.
    await p.write_with_length(writer, 5)
    expected = [unittest.mock.call(0), unittest.mock.call(5)]
    assert recorded.call_args_list == expected
    recorded.call_args_list.clear()

    # Unbounded write: all 16000 bytes are sent.
    await p.write_with_length(writer, None)
    expected = [unittest.mock.call(0), unittest.mock.call(16000)]
    assert recorded.call_args_list == expected


async def test_iobase_payload_exact_chunk_size_limit() -> None:
"""Test IOBasePayload with content length matching exactly one read chunk."""
chunk_size = 2**16 # 65536 bytes (READ_SIZE)
Expand Down Expand Up @@ -576,6 +680,47 @@ async def gen() -> AsyncIterator[bytes]:
assert len(writer.get_written_bytes()) == 4


@pytest.mark.parametrize(
    ("content_length", "expected_calls"),
    [
        (
            6,
            [
                unittest.mock.call(0),
                unittest.mock.call(4),
                unittest.mock.call(6),
            ],
        ),
        (
            None,
            [
                unittest.mock.call(0),
                unittest.mock.call(4),
                unittest.mock.call(8),
                unittest.mock.call(10),
            ],
        ),
    ],
)
async def test_async_iterable_payload_write_chunked_progress_callback(
    content_length: int | None, expected_calls: list[unittest.mock._Call]
) -> None:
    """AsyncIterablePayload reports a running total per chunk, truncating mid-chunk."""

    async def gen() -> AsyncIterator[bytes]:
        yield b"0123"
        yield b"4567"
        yield b"89"

    recorded = unittest.mock.Mock()
    p = payload.AsyncIterablePayload(gen())
    p.set_progress_callback(recorded)
    writer = MockStreamWriter()

    await p.write_with_length(writer, content_length)
    assert recorded.call_args_list == expected_calls


async def test_bytes_payload_backwards_compatibility() -> None:
"""Test BytesPayload.write() backwards compatibility delegates to write_with_length()."""
p = payload.BytesPayload(b"1234567890")
Expand Down
Loading