diff --git a/tornado/iostream.py b/tornado/iostream.py index 53e81fff3..e2a833ff8 100644 --- a/tornado/iostream.py +++ b/tornado/iostream.py @@ -1256,6 +1256,7 @@ def start_tls( ssl_stream._ssl_connect_future = future ssl_stream.max_buffer_size = self.max_buffer_size ssl_stream.read_chunk_size = self.read_chunk_size + future._ssl_stream = ssl_stream return future def _handle_connect(self) -> None: diff --git a/tornado/tcpclient.py b/tornado/tcpclient.py index 04a0c84f9..776493712 100644 --- a/tornado/tcpclient.py +++ b/tornado/tcpclient.py @@ -273,12 +273,16 @@ async def connect( # the same host. (http://tools.ietf.org/html/rfc6555#section-4.2) if ssl_options is not None: if timeout is not None: - stream = await gen.with_timeout( - timeout, - stream.start_tls( - False, ssl_options=ssl_options, server_hostname=host - ), + handshake_future: "Future[IOStream]" = stream.start_tls( + False, ssl_options=ssl_options, server_hostname=host ) + try: + stream = await gen.with_timeout(timeout, handshake_future) + except TimeoutError: + ssl_stream = getattr(handshake_future, "_ssl_stream", None) + if ssl_stream is not None: + ssl_stream.close() + raise else: stream = await stream.start_tls( False, ssl_options=ssl_options, server_hostname=host diff --git a/tornado/test/tcpclient_test.py b/tornado/test/tcpclient_test.py index ffe65d322..e9633ec9c 100644 --- a/tornado/test/tcpclient_test.py +++ b/tornado/test/tcpclient_test.py @@ -14,10 +14,12 @@ # under the License. import getpass import socket +import ssl import typing import unittest from contextlib import closing +from tornado import gen from tornado.concurrent import Future from tornado.gen import TimeoutError from tornado.iostream import IOStream @@ -173,6 +175,68 @@ def resolve(self, *args, **kwargs): "1.2.3.4", 12345, timeout=timeout ) + @gen_test + def test_tls_handshake_timeout_closes_ssl_stream(self): + # Regression test for issue #3614: when TCPClient.connect is + # called with both ssl_options and a timeout, a TLS handshake + # timeout must close the SSLIOStream that owns the underlying + # socket, not leak it. + port = self.start_server(socket.AF_INET) + + original_start_tls = IOStream.start_tls + closed_streams: list[IOStream] = [] + tls_future_holder: list[Future[IOStream]] = [] + + def fake_start_tls( + self: IOStream, + server_side: bool, + ssl_options: typing.Any = None, + server_hostname: typing.Any = None, + ) -> Future[IOStream]: + from tornado.iostream import SSLIOStream + + real_socket = self.socket + self.io_loop.remove_handler(real_socket) + self.socket = None # type: ignore[assignment] + ssl_stream = SSLIOStream(real_socket, ssl_options=ssl_options) + original = self._close_callback + if original is not None: + ssl_stream.set_close_callback(original) + real_close = ssl_stream.close + + def tracking_close(*args: typing.Any, **kwargs: typing.Any) -> None: + closed_streams.append(ssl_stream) + real_close(*args, **kwargs) + + ssl_stream.close = tracking_close # type: ignore[method-assign] + tls_future: Future[IOStream] = Future() + tls_future_holder.append(tls_future) + return tls_future + + IOStream.start_tls = fake_start_tls # type: ignore[method-assign] + try: + with self.assertRaises(TimeoutError): + yield self.client.connect( + "127.0.0.1", + port, + ssl_options=dict(cert_reqs=ssl.CERT_NONE), + timeout=0.05, + ) + # The handshake future is still pending. Resolve it so the + # cleanup callback registered by the timeout handler runs + # and closes the SSLIOStream that owns the socket. + self.assertEqual(len(tls_future_holder), 1) + tls_future_holder[0].set_result(None) # type: ignore[arg-type] + yield gen.sleep(0) + finally: + IOStream.start_tls = original_start_tls # type: ignore[method-assign] + self.assertEqual( + len(closed_streams), + 1, + "SSLIOStream owning the underlying socket was not closed on " + "TLS handshake timeout (issue #3614)", + ) + class TestConnectorSplit(unittest.TestCase): def test_one_family(self):