Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion cheroot/makefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@

# prefer slower Python-based io module
import _pyio as io
import select
import socket


# Write only 16K at a time to sockets
SOCK_WRITE_BLOCKSIZE = 16384

# Seconds to wait for a blocked socket to become writable
SOCK_WRITE_TIMEOUT = 10


class BufferedWriter(io.BufferedWriter):
"""Faux file object attached to a socket object."""
Expand All @@ -26,12 +30,28 @@ def write(self, b):
def _flush_unlocked(self):
self._checkClosed('flush of closed file')
while self._write_buf:
n = None
try:
# ssl sockets only except 'bytes', not bytearrays
# so perhaps we should conditionally wrap this for perf?
n = self.raw.write(bytes(self._write_buf))
n = self.raw.write(
bytes(self._write_buf[:SOCK_WRITE_BLOCKSIZE]),
)
except io.BlockingIOError as e:
n = e.characters_written
if n is None:
_, writable, _ = select.select(
[],
[self.raw],
[],
SOCK_WRITE_TIMEOUT,
)
if not writable:
raise io.BlockingIOError(
0,
'raw stream blocked; no bytes written',
)
continue
del self._write_buf[:n]


Expand Down
1 change: 1 addition & 0 deletions cheroot/makefile.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import io

SOCK_WRITE_BLOCKSIZE: int
SOCK_WRITE_TIMEOUT: int

class BufferedWriter(io.BufferedWriter):
def write(self, b): ...
Expand Down
97 changes: 97 additions & 0 deletions cheroot/test/test_makefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,100 @@ def test_bytes_written():
wfile = makefile.MakeFile(sock, 'w')
wfile.write(b'bar')
assert wfile.bytes_written == 3


class _RawWriteBlockOnce:
"""Mock raw.write() returning None once, then writing normally."""

def __init__(self):
"""Initialize _RawWriteBlockOnce."""
self.call_count = 0
self.written = bytearray()

def __call__(self, chunk):
"""Return None on first call to simulate a blocked write."""
self.call_count += 1
if self.call_count == 1:
return None
self.written.extend(chunk)
return len(chunk)

def fileno(self):
"""Return a fake fd for select()."""
return -1


class _RawWriteBlockAlways:
"""Mock raw.write() that always returns None."""

def __init__(self):
"""Initialize _RawWriteBlockAlways."""
self.call_count = 0

def __call__(self, chunk):
"""Return None to simulate a permanently blocked socket."""
self.call_count += 1

def fileno(self):
"""Return a fake fd for select()."""
return -1


def test_flush_recovers_from_temporary_block(monkeypatch):
"""_flush_unlocked() retries after select when raw.write() returns None.

A temporarily blocked socket should recover once select() reports
the socket is writable again, delivering all buffered data.
"""
data = b'x' * (makefile.SOCK_WRITE_BLOCKSIZE * 2)

sock = MockSocket()
wfile = makefile.MakeFile(sock, 'w')
wfile._write_buf.extend(data)

mock = _RawWriteBlockOnce()
wfile.raw.write = mock

# select() reports writable immediately
monkeypatch.setattr(
'cheroot.makefile.select.select',
lambda _rlist, wlist, _xlist, _timeout: ([], wlist, []),
)
wfile._flush_unlocked()

assert bytes(mock.written) == data, (
'all buffered data should be written after select retry'
)


def test_flush_raises_on_sustained_block(monkeypatch):
"""_flush_unlocked() raises BlockingIOError after select timeout.

If the socket stays blocked past SOCK_WRITE_TIMEOUT, the write
buffer must be preserved and BlockingIOError raised.
"""
import io

import pytest

data = b'x' * makefile.SOCK_WRITE_BLOCKSIZE

sock = MockSocket()
wfile = makefile.MakeFile(sock, 'w')
wfile._write_buf.extend(data)

mock = _RawWriteBlockAlways()
wfile.raw.write = mock

# select() reports not writable (timeout)
monkeypatch.setattr(
'cheroot.makefile.select.select',
lambda _rlist, _wlist, _xlist, _timeout: ([], [], []),
)

with pytest.raises(io.BlockingIOError):
wfile._flush_unlocked()

assert len(wfile._write_buf) == len(data), (
'write buffer must be preserved when socket stays blocked'
)
3 changes: 3 additions & 0 deletions docs/changelog-fragments.d/822.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fixed a bug that could cause premature clearing of the write buffer when a socket write is blocked.

-- by :user:`cbbm142`
3 changes: 3 additions & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ backports
bugfixes
builtin
b'xb
buf
compat
config
conftest
Expand All @@ -22,6 +23,7 @@ hardcoded
hostname
inclusivity
intersphinx
io
iterable
linter
linters
Expand All @@ -48,6 +50,7 @@ preconfigure
py
pytest
pythonic
RawIOBase
readonly
rebase
Refactor
Expand Down
Loading