Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions src/neo4j/_async_compat/network/_bolt_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _non_expired_timeout(
class AsyncBoltSocketBase(abc.ABC):
Bolt: te.Final[type[AsyncBolt]] = None # type: ignore[assignment]

def __init__(self, reader, protocol, writer) -> None:
def __init__(self, reader, protocol, writer, sockname, peername) -> None:
self._reader = reader # type: asyncio.StreamReader
self._protocol = protocol # type: asyncio.StreamReaderProtocol
self._writer = writer # type: asyncio.StreamWriter
Expand All @@ -109,6 +109,8 @@ def __init__(self, reader, protocol, writer) -> None:
# int - seconds to wait for data
self._timeout: float | None = None
self._deadline: Deadline | None = None
self._sockname = sockname
self._peername = peername

async def _wait_for_io(
self,
Expand Down Expand Up @@ -157,10 +159,10 @@ def _socket(self) -> socket:
return self._writer.transport.get_extra_info("socket")

def getsockname(self):
return self._writer.transport.get_extra_info("sockname")
return self._sockname

def getpeername(self):
return self._writer.transport.get_extra_info("peername")
return self._peername

def getpeercert(self, *args, **kwargs):
return self._writer.transport.get_extra_info("ssl_object").getpeercert(
Expand Down Expand Up @@ -230,7 +232,10 @@ async def _connect_secure(
if timeout == 0: # socket timeout of 0 => non-blocking
timeout = None
await wait_for(loop.sock_connect(s, resolved_address), timeout)
local_port = s.getsockname()[1]

sockname = s.getsockname()
peername = s.getpeername()
local_port = sockname[1]

keep_alive = 1 if keep_alive else 0
s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive)
Expand Down Expand Up @@ -271,14 +276,13 @@ async def _connect_secure(
"ssl_object"
).getpeercert(binary_form=True)
if der_encoded_server_certificate is None:
local_port = s.getsockname()[1]
raise BoltProtocolError(
"When using an encrypted socket, the server should "
"always provide a certificate",
address=(resolved_address._host_name, local_port),
)

return cls(reader, protocol, writer)
return cls(reader, protocol, writer, sockname, peername)

except asyncio.TimeoutError:
log.debug("[#0000] S: <TIMEOUT> %s", resolved_address)
Expand Down Expand Up @@ -363,8 +367,10 @@ def _kill_raw_socket(cls, socket_):
class BoltSocketBase:
Bolt: te.Final[type[Bolt]] = None # type: ignore[assignment]

def __init__(self, socket_: socket):
def __init__(self, socket_: socket, sockname, peername):
self._socket = socket_
self._sockname = sockname
self._peername = peername
self._deadline: Deadline | None = None

@property
Expand All @@ -374,21 +380,23 @@ def _socket(self) -> socket | SSLSocket:
@_socket.setter
def _socket(self, socket_: socket | SSLSocket) -> None:
self.__socket = socket_
self.getsockname = socket_.getsockname
self.getpeername = socket_.getpeername
if hasattr(socket, "getpeercert"):
self.getpeercert = t.cast(SSLSocket, socket_).getpeercert
elif "getpeercert" in self.__dict__:
del self.__dict__["getpeercert"]
self.gettimeout = socket_.gettimeout
self.settimeout = socket_.settimeout

getsockname: t.Callable = None # type: ignore
getpeername: t.Callable = None # type: ignore
getpeercert: t.Callable = None # type: ignore
gettimeout: t.Callable = None # type: ignore
settimeout: t.Callable = None # type: ignore

def getsockname(self):
return self._sockname

def getpeername(self):
return self._peername

def _wait_for_io(
self,
func: t.Callable[_P, t.Any],
Expand Down Expand Up @@ -488,7 +496,9 @@ def _connect_secure(
) from error
raise

local_port = s.getsockname()[1]
sockname = s.getsockname()
peername = s.getpeername()
local_port = sockname[1]
# Secure the connection if an SSL context has been provided
if ssl_context:
hostname = resolved_address._host_name or None
Expand Down Expand Up @@ -529,7 +539,7 @@ def _connect_secure(
cls._kill_raw_socket(s)
raise

return cls(s)
return cls(s, sockname, peername)

@abc.abstractmethod
def _handshake(self, resolved_address, deadline): ...
Expand Down
19 changes: 7 additions & 12 deletions tests/unit/fixtures/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,23 +71,18 @@ async def drain():
bytes_written.extend(write_buffer)
write_buffer.clear()

def transport_get_extra(key):
if key == "sockname":
return "localhost", 0x1234
if key == "peername":
return "peer_name"
raise KeyError(f"not mocked: {key}")

reader = mocker.Mock(spec=asyncio.StreamReader)
writer = mocker.Mock(spec=asyncio.StreamWriter)
protocol = mocker.Mock(spec=asyncio.StreamReaderProtocol)

reader.read.side_effect = read
writer.write.side_effect = write
writer.drain.side_effect = drain
writer.transport.get_extra_info.side_effect = transport_get_extra

return AsyncBoltSocket(reader, protocol, writer)
sockname = "localhost", 0x1234
peername = "peer_name"

return AsyncBoltSocket(reader, protocol, writer, sockname, peername)

return factory

Expand Down Expand Up @@ -120,10 +115,10 @@ def send_all(b):
socket_mock.recv.side_effect = recv
socket_mock.recv_into.side_effect = recv_into
socket_mock.sendall.side_effect = send_all
socket_mock.getsockname.return_value = ("localhost", 0x1234)
socket_mock.getpeername.return_value = "peer_name"
sockname = "localhost", 0x1234
peername = "peer_name"
socket_mock.gettimeout.return_value = None

return BoltSocket(socket_mock)
return BoltSocket(socket_mock, sockname, peername)

return factory
6 changes: 5 additions & 1 deletion tests/unit/mixed/async_compat/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@ def factory():
def socket_factory(reader_factory, writer_factory):
def factory():
protocol = None
return AsyncBoltSocket(reader_factory(), protocol, writer_factory())
sockname = "localhost", 0x1234
peername = "peer_name"
return AsyncBoltSocket(
reader_factory(), protocol, writer_factory(), sockname, peername
)

return factory

Expand Down
Loading