Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ The third digit is only for regressions.
Backward-incompatible changes:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

- ``OpenSSL.SSL.Connection.set_session`` now raises ``ValueError`` if the ``Session`` was obtained from a ``Connection`` that was using a different ``Context`` than this one. OpenSSL requires (but does not verify) that sessions only be re-used with a compatible ``SSL_CTX``, so this contract is now enforced.

Deprecations:
^^^^^^^^^^^^^

Expand Down
17 changes: 17 additions & 0 deletions src/OpenSSL/SSL.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,11 @@ class Session:
"""

_session: Any
# The Context the Connection this Session came from was using. OpenSSL
# requires that a session only be re-used with a compatible SSL_CTX, but
# doesn't verify it, so we pin the Context here and enforce identity in
# Connection.set_session.
_context: Context


F = TypeVar("F", bound=Callable[..., Any])
Expand Down Expand Up @@ -3032,12 +3037,18 @@ def get_session(self) -> Session | None:

pysession = Session.__new__(Session)
pysession._session = _ffi.gc(session, _lib.SSL_SESSION_free)
pysession._context = self._context
return pysession

def set_session(self, session: Session) -> None:
"""
Set the session to be used when the TLS/SSL connection is established.

The session must have been obtained, via :meth:`get_session`, from a
:class:`Connection` that was using the same :class:`Context` as this
one. OpenSSL requires (but does not verify) that sessions only be
re-used with a compatible ``SSL_CTX``, so this is enforced here.

:param session: A Session instance representing the session to use.
:returns: None

Expand All @@ -3046,6 +3057,12 @@ def set_session(self, session: Session) -> None:
if not isinstance(session, Session):
raise TypeError("session must be a Session instance")

if session._context is not self._context:
raise ValueError(
"session must have been created by a Connection using the "
"same Context as this one"
)

result = _lib.SSL_set_session(self._ssl, session._session)
_openssl_assert(result == 1)

Expand Down
48 changes: 22 additions & 26 deletions tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3108,12 +3108,22 @@ def makeServer(socket: socket) -> Connection:
server.set_accept_state()
return server

originalServer, originalClient = loopback(server_factory=makeServer)
clientCtx = Context(SSLv23_METHOD)

def makeOriginalClient(socket: socket) -> Connection:
client = Connection(clientCtx, socket)
client.set_connect_state()
return client

originalServer, originalClient = loopback(
server_factory=makeServer, client_factory=makeOriginalClient
)
originalSession = originalClient.get_session()
assert originalSession is not None

def makeClient(socket: socket) -> Connection:
client = loopback_client_factory(socket)
client = Connection(clientCtx, socket)
client.set_connect_state()
client.set_session(originalSession)
return client

Expand All @@ -3129,18 +3139,15 @@ def makeClient(socket: socket) -> Connection:
# connections is the same, the session was re-used!
assert originalServer.master_key() == resumedServer.master_key()

def test_set_session_wrong_method(self) -> None:
def test_set_session_wrong_context(self) -> None:
"""
If `Connection.set_session` is passed a `Session` instance associated
with a context using a different SSL method than the `Connection`
is using, a `OpenSSL.SSL.Error` is raised.
If `Connection.set_session` is passed a `Session` instance that was
created by a `Connection` using a different `Context` than the
`Connection` is using, a `ValueError` is raised.
"""
v1 = TLSv1_2_METHOD
v2 = TLSv1_METHOD

key = load_privatekey(FILETYPE_PEM, server_key_pem)
cert = load_certificate(FILETYPE_PEM, server_cert_pem)
ctx = Context(v1)
ctx = Context(TLSv1_2_METHOD)
ctx.use_privatekey(key)
ctx.use_certificate(cert)
ctx.set_session_id(b"unity-test")
Expand All @@ -3150,26 +3157,15 @@ def makeServer(socket: socket) -> Connection:
server.set_accept_state()
return server

def makeOriginalClient(socket: socket) -> Connection:
client = Connection(Context(v1), socket)
client.set_connect_state()
return client

_, originalClient = loopback(
server_factory=makeServer, client_factory=makeOriginalClient
)
_, originalClient = loopback(server_factory=makeServer)
originalSession = originalClient.get_session()
assert originalSession is not None

def makeClient(socket: socket) -> Connection:
# Intentionally use a different, incompatible method here.
client = Connection(Context(v2), socket)
client.set_connect_state()
# Intentionally use a different Context here.
client = Connection(Context(SSLv23_METHOD), None)
client.set_connect_state()
with pytest.raises(ValueError):
client.set_session(originalSession)
return client

with pytest.raises(Error):
loopback(client_factory=makeClient, server_factory=makeServer)

def test_wantWriteError(self) -> None:
"""
Expand Down
Loading