From f1ec1fa16108678a8c3ac6dbdd8c3c06d76aa0df Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 5 Jun 2026 14:24:59 +0200 Subject: [PATCH 1/2] Cache sockname and peername Async ----- Instead of relying on asyncio's caching of these properties, we do it ourselves. The advantage is that asyncio populates the cache on a best-effort basis. This can lead to the values being `None` if retrieving them causes an `OSError`. This is a misalignment between the async and sync driver that this PR aims to remedy. Sync ---- While in then async driver we're introducing (custom) caching where implicit caching was already in place, we also introduce caching of these fields in the sync driver for parity. --- .../_async_compat/network/_bolt_socket.py | 37 ++++++++++++------- tests/unit/fixtures/socket.py | 19 ++++------ tests/unit/mixed/async_compat/test_network.py | 6 ++- 3 files changed, 36 insertions(+), 26 deletions(-) diff --git a/src/neo4j/_async_compat/network/_bolt_socket.py b/src/neo4j/_async_compat/network/_bolt_socket.py index 42f35afbd..a1ee19022 100644 --- a/src/neo4j/_async_compat/network/_bolt_socket.py +++ b/src/neo4j/_async_compat/network/_bolt_socket.py @@ -100,7 +100,7 @@ def _non_expired_timeout( class AsyncBoltSocketBase(abc.ABC): Bolt: te.Final[type[AsyncBolt]] = None # type: ignore[assignment] - def __init__(self, reader, protocol, writer) -> None: + def __init__(self, reader, protocol, writer, sockname, peername) -> None: self._reader = reader # type: asyncio.StreamReader self._protocol = protocol # type: asyncio.StreamReaderProtocol self._writer = writer # type: asyncio.StreamWriter @@ -109,6 +109,8 @@ def __init__(self, reader, protocol, writer) -> None: # int - seconds to wait for data self._timeout: float | None = None self._deadline: Deadline | None = None + self._sockname = sockname + self._peername = peername async def _wait_for_io( self, @@ -157,10 +159,10 @@ def _socket(self) -> socket: return self._writer.transport.get_extra_info("socket") def getsockname(self): - return self._writer.transport.get_extra_info("sockname") + return self._sockname def getpeername(self): - return self._writer.transport.get_extra_info("peername") + return self._peername def getpeercert(self, *args, **kwargs): return self._writer.transport.get_extra_info("ssl_object").getpeercert( @@ -230,7 +232,10 @@ async def _connect_secure( if timeout == 0: # socket timeout of 0 => non-blocking timeout = None await wait_for(loop.sock_connect(s, resolved_address), timeout) - local_port = s.getsockname()[1] + + sockname = s.getsockname() + peername = s.getpeername() + local_port = sockname[1] keep_alive = 1 if keep_alive else 0 s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive) @@ -271,14 +276,13 @@ async def _connect_secure( "ssl_object" ).getpeercert(binary_form=True) if der_encoded_server_certificate is None: - local_port = s.getsockname()[1] raise BoltProtocolError( "When using an encrypted socket, the server should " "always provide a certificate", address=(resolved_address._host_name, local_port), ) - return cls(reader, protocol, writer) + return cls(reader, protocol, writer, sockname, peername) except asyncio.TimeoutError: log.debug("[#0000] S: %s", resolved_address) @@ -363,8 +367,10 @@ def _kill_raw_socket(cls, socket_): class BoltSocketBase: Bolt: te.Final[type[Bolt]] = None # type: ignore[assignment] - def __init__(self, socket_: socket): + def __init__(self, socket_: socket, sockname, peername): self._socket = socket_ + self._sockname = sockname + self._peername = peername self._deadline: Deadline | None = None @property @@ -374,21 +380,24 @@ def _socket(self) -> socket | SSLSocket: @_socket.setter def _socket(self, socket_: socket | SSLSocket) -> None: self.__socket = socket_ - self.getsockname = socket_.getsockname - self.getpeername = socket_.getpeername if hasattr(socket, "getpeercert"): self.getpeercert = t.cast(SSLSocket, socket_).getpeercert elif "getpeercert" in self.__dict__: del self.__dict__["getpeercert"] + socket_.getsockname() self.gettimeout = socket_.gettimeout self.settimeout = socket_.settimeout - getsockname: t.Callable = None # type: ignore - getpeername: t.Callable = None # type: ignore getpeercert: t.Callable = None # type: ignore gettimeout: t.Callable = None # type: ignore settimeout: t.Callable = None # type: ignore + def getsockname(self): + return self._sockname + + def getpeername(self): + return self._peername + def _wait_for_io( self, func: t.Callable[_P, t.Any], @@ -488,7 +497,9 @@ def _connect_secure( ) from error raise - local_port = s.getsockname()[1] + sockname = s.getsockname() + peername = s.getpeername() + local_port = sockname[1] # Secure the connection if an SSL context has been provided if ssl_context: hostname = resolved_address._host_name or None @@ -529,7 +540,7 @@ def _connect_secure( cls._kill_raw_socket(s) raise - return cls(s) + return cls(s, sockname, peername) @abc.abstractmethod def _handshake(self, resolved_address, deadline): ... diff --git a/tests/unit/fixtures/socket.py b/tests/unit/fixtures/socket.py index 91c8221b5..bf23289c4 100644 --- a/tests/unit/fixtures/socket.py +++ b/tests/unit/fixtures/socket.py @@ -71,13 +71,6 @@ async def drain(): bytes_written.extend(write_buffer) write_buffer.clear() - def transport_get_extra(key): - if key == "sockname": - return "localhost", 0x1234 - if key == "peername": - return "peer_name" - raise KeyError(f"not mocked: {key}") - reader = mocker.Mock(spec=asyncio.StreamReader) writer = mocker.Mock(spec=asyncio.StreamWriter) protocol = mocker.Mock(spec=asyncio.StreamReaderProtocol) @@ -85,9 +78,11 @@ def transport_get_extra(key): reader.read.side_effect = read writer.write.side_effect = write writer.drain.side_effect = drain - writer.transport.get_extra_info.side_effect = transport_get_extra - return AsyncBoltSocket(reader, protocol, writer) + sockname = "localhost", 0x1234 + peername = "peer_name" + + return AsyncBoltSocket(reader, protocol, writer, sockname, peername) return factory @@ -120,10 +115,10 @@ def send_all(b): socket_mock.recv.side_effect = recv socket_mock.recv_into.side_effect = recv_into socket_mock.sendall.side_effect = send_all - socket_mock.getsockname.return_value = ("localhost", 0x1234) - socket_mock.getpeername.return_value = "peer_name" + sockname = "localhost", 0x1234 + peername = "peer_name" socket_mock.gettimeout.return_value = None - return BoltSocket(socket_mock) + return BoltSocket(socket_mock, sockname, peername) return factory diff --git a/tests/unit/mixed/async_compat/test_network.py b/tests/unit/mixed/async_compat/test_network.py index 1b2677458..c014a4dd0 100644 --- a/tests/unit/mixed/async_compat/test_network.py +++ b/tests/unit/mixed/async_compat/test_network.py @@ -84,7 +84,11 @@ def factory(): def socket_factory(reader_factory, writer_factory): def factory(): protocol = None - return AsyncBoltSocket(reader_factory(), protocol, writer_factory()) + sockname = "localhost", 0x1234 + peername = "peer_name" + return AsyncBoltSocket( + reader_factory(), protocol, writer_factory(), sockname, peername + ) return factory From 32c762f46e28a8cb9cf64deed6a80b2d57e4a0c3 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 8 Jun 2026 12:08:56 +0200 Subject: [PATCH 2/2] Clean-up --- src/neo4j/_async_compat/network/_bolt_socket.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/neo4j/_async_compat/network/_bolt_socket.py b/src/neo4j/_async_compat/network/_bolt_socket.py index a1ee19022..59a28fae8 100644 --- a/src/neo4j/_async_compat/network/_bolt_socket.py +++ b/src/neo4j/_async_compat/network/_bolt_socket.py @@ -384,7 +384,6 @@ def _socket(self, socket_: socket | SSLSocket) -> None: self.getpeercert = t.cast(SSLSocket, socket_).getpeercert elif "getpeercert" in self.__dict__: del self.__dict__["getpeercert"] - socket_.getsockname() self.gettimeout = socket_.gettimeout self.settimeout = socket_.settimeout