diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 844a10d2..6cf29e6e 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -10,6 +10,8 @@ The third digit is only for regressions. Backward-incompatible changes: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +- ``OpenSSL.SSL.Connection.set_session`` now raises ``ValueError`` if the ``Session`` was obtained from a ``Connection`` that was using a different ``Context`` than this one. OpenSSL requires (but does not verify) that sessions only be re-used with a compatible ``SSL_CTX``, so this contract is now enforced. + Deprecations: ^^^^^^^^^^^^^ diff --git a/src/OpenSSL/SSL.py b/src/OpenSSL/SSL.py index 90065e95..e951eae6 100644 --- a/src/OpenSSL/SSL.py +++ b/src/OpenSSL/SSL.py @@ -847,6 +847,11 @@ class Session: """ _session: Any + # The Context the Connection this Session came from was using. OpenSSL + # requires that a session only be re-used with a compatible SSL_CTX, but + # doesn't verify it, so we pin the Context here and enforce identity in + # Connection.set_session. + _context: Context F = TypeVar("F", bound=Callable[..., Any]) @@ -3032,12 +3037,18 @@ def get_session(self) -> Session | None: pysession = Session.__new__(Session) pysession._session = _ffi.gc(session, _lib.SSL_SESSION_free) + pysession._context = self._context return pysession def set_session(self, session: Session) -> None: """ Set the session to be used when the TLS/SSL connection is established. + The session must have been obtained, via :meth:`get_session`, from a + :class:`Connection` that was using the same :class:`Context` as this + one. OpenSSL requires (but does not verify) that sessions only be + re-used with a compatible ``SSL_CTX``, so this is enforced here. + :param session: A Session instance representing the session to use. :returns: None @@ -3046,6 +3057,12 @@ def set_session(self, session: Session) -> None: if not isinstance(session, Session): raise TypeError("session must be a Session instance") + if session._context is not self._context: + raise ValueError( + "session must have been created by a Connection using the " + "same Context as this one" + ) + result = _lib.SSL_set_session(self._ssl, session._session) _openssl_assert(result == 1) diff --git a/tests/test_ssl.py b/tests/test_ssl.py index b3da8a6d..1e6939c3 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -3108,12 +3108,22 @@ def makeServer(socket: socket) -> Connection: server.set_accept_state() return server - originalServer, originalClient = loopback(server_factory=makeServer) + clientCtx = Context(SSLv23_METHOD) + + def makeOriginalClient(socket: socket) -> Connection: + client = Connection(clientCtx, socket) + client.set_connect_state() + return client + + originalServer, originalClient = loopback( + server_factory=makeServer, client_factory=makeOriginalClient + ) originalSession = originalClient.get_session() assert originalSession is not None def makeClient(socket: socket) -> Connection: - client = loopback_client_factory(socket) + client = Connection(clientCtx, socket) + client.set_connect_state() client.set_session(originalSession) return client @@ -3129,18 +3139,15 @@ def makeClient(socket: socket) -> Connection: # connections is the same, the session was re-used! assert originalServer.master_key() == resumedServer.master_key() - def test_set_session_wrong_method(self) -> None: + def test_set_session_wrong_context(self) -> None: """ - If `Connection.set_session` is passed a `Session` instance associated - with a context using a different SSL method than the `Connection` - is using, a `OpenSSL.SSL.Error` is raised. + If `Connection.set_session` is passed a `Session` instance that was + created by a `Connection` using a different `Context` than the + `Connection` is using, a `ValueError` is raised. """ - v1 = TLSv1_2_METHOD - v2 = TLSv1_METHOD - key = load_privatekey(FILETYPE_PEM, server_key_pem) cert = load_certificate(FILETYPE_PEM, server_cert_pem) - ctx = Context(v1) + ctx = Context(TLSv1_2_METHOD) ctx.use_privatekey(key) ctx.use_certificate(cert) ctx.set_session_id(b"unity-test") @@ -3150,26 +3157,15 @@ def makeServer(socket: socket) -> Connection: server.set_accept_state() return server - def makeOriginalClient(socket: socket) -> Connection: - client = Connection(Context(v1), socket) - client.set_connect_state() - return client - - _, originalClient = loopback( - server_factory=makeServer, client_factory=makeOriginalClient - ) + _, originalClient = loopback(server_factory=makeServer) originalSession = originalClient.get_session() assert originalSession is not None - def makeClient(socket: socket) -> Connection: - # Intentionally use a different, incompatible method here. - client = Connection(Context(v2), socket) - client.set_connect_state() + # Intentionally use a different Context here. + client = Connection(Context(SSLv23_METHOD), None) + client.set_connect_state() + with pytest.raises(ValueError): client.set_session(originalSession) - return client - - with pytest.raises(Error): - loopback(client_factory=makeClient, server_factory=makeServer) def test_wantWriteError(self) -> None: """