From 734254a3c23d72d88c6ac50d9f3bcb090e6200de Mon Sep 17 00:00:00 2001 From: Yuval Elbar Date: Fri, 17 Apr 2026 08:45:09 +0300 Subject: [PATCH] Fix path-traversal vulnerability in emergency P2P checkpoint service A malicious or compromised peer on the P2P network could supply a manifest whose rel_path contained '..' segments or an absolute path, causing P2PNode.fetch_shard_from_peer() to write attacker-controlled bytes outside the staging directory (e.g. a .pth file in site-packages, yielding persistent RCE on the training host). - Add _safe_path_join() which joins a peer-supplied relative path onto a base directory only if the resolved result stays inside that base. Resolution goes through os.path.realpath so symlink-escape attempts are caught as well. - Apply the helper on both sides of the wire: * Client: fetch_shard_from_peer() validates every manifest entry against stage_dir and aborts the whole fetch on any unsafe entry. * Server: handle_download() replaces the substring '..' check with the same resolve-based containment check against self.directory. - Log every rejection with peer and request context. - Add regression tests for the helper and both call sites. Reported via the Google OSS VRP. --- .../experimental/emergency/p2p/service.py | 55 +++++++- .../emergency/p2p/service_test.py | 130 ++++++++++++++++++ 2 files changed, 182 insertions(+), 3 deletions(-) diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/service.py b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/service.py index 8c8e83117..9ed417d4c 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/service.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/service.py @@ -16,6 +16,7 @@ import concurrent.futures import functools +import os import shutil import socket import socketserver @@ -31,6 +32,43 @@ from orbax.checkpoint.experimental.emergency.p2p import utils +def _safe_path_join(base: epath.Path, rel_path: str) -> epath.Path | None: + """Joins ``rel_path`` onto ``base`` iff the result stays within ``base``. + + Rejects empty strings, absolute paths (including Windows drive-prefixed + paths), and any path whose resolved form escapes ``base`` via ``..`` + components or symlinks. Returns ``None`` for unsafe inputs so callers can + fail closed. + + Args: + base: The base directory that the resulting path must be contained within. + rel_path: A peer-supplied relative path. Treated as untrusted. + + Returns: + ``base / rel_path`` on success, or ``None`` if ``rel_path`` is unsafe. + """ + if not rel_path: + return None + # Reject absolute paths. ``os.path.isabs`` covers POSIX roots and (on + # Windows) drive-prefixed paths; ``splitdrive`` catches e.g. ``C:foo`` which + # is technically relative but still anchored off ``base``. + if os.path.isabs(rel_path) or os.path.splitdrive(rel_path)[0]: + return None + try: + base_real = os.path.realpath(str(base)) + candidate_real = os.path.realpath(os.path.join(base_real, rel_path)) + except OSError: + return None + try: + common = os.path.commonpath([base_real, candidate_real]) + except ValueError: + # Different drives on Windows, or otherwise incomparable. + return None + if common != base_real: + return None + return base / rel_path + + class _ThreadingTCPServer(socketserver.ThreadingTCPServer): """A ThreadingTCPServer that holds a reference to a P2PNode.""" @@ -206,12 +244,12 @@ def handle_download(self, sock, payload: dict[str, Any]): """ rel_path_str = payload.get('rel_path') - if not rel_path_str or '..' in rel_path_str or rel_path_str.startswith('/'): + full_path = _safe_path_join(self.directory, rel_path_str) if rel_path_str else None + if full_path is None: logging.error('Blocked unsafe P2P path request: %s', rel_path_str) protocol.TCPMessage.send_file(sock, epath.Path('__INVALID__')) return - full_path = self.directory / rel_path_str if full_path.exists() and full_path.is_file(): protocol.TCPMessage.send_file(sock, full_path) else: @@ -268,7 +306,18 @@ def fetch_shard_from_peer( futures = [] for f_meta in manifest: rel_path_str = f_meta['rel_path'] - dest_path = stage_dir / rel_path_str + dest_path = _safe_path_join(stage_dir, rel_path_str) + if dest_path is None: + logging.error( + 'Rejecting unsafe manifest entry from peer %s:%d for' + ' step=%d, process_index=%d: %r', + ip, + port, + step, + stored_process_index, + rel_path_str, + ) + return False futures.append( exc.submit( protocol.TCPClient.download, ip, port, rel_path_str, dest_path diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/service_test.py b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/service_test.py index e1b7f3989..33c7d3da2 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/p2p/service_test.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/p2p/service_test.py @@ -15,6 +15,7 @@ """Unit tests for P2PNode service.""" import functools +import os import threading from unittest import mock @@ -23,6 +24,135 @@ from orbax.checkpoint.experimental.emergency.p2p import service +class SafePathJoinTest(absltest.TestCase): + """Tests for the ``_safe_path_join`` path-containment helper.""" + + def setUp(self): + super().setUp() + self.base = epath.Path(self.create_tempdir().full_path) + + def test_valid_relative_path(self): + result = service._safe_path_join(self.base, '1/file1') + self.assertIsNotNone(result) + self.assertEqual(str(result), str(self.base / '1/file1')) + + def test_valid_nested_path(self): + result = service._safe_path_join(self.base, '1/subdir/file2') + self.assertIsNotNone(result) + + def test_empty_rel_path_rejected(self): + self.assertIsNone(service._safe_path_join(self.base, '')) + + def test_absolute_posix_path_rejected(self): + self.assertIsNone(service._safe_path_join(self.base, '/etc/passwd')) + + def test_parent_traversal_rejected(self): + self.assertIsNone(service._safe_path_join(self.base, '../secret')) + + def test_deep_parent_traversal_rejected(self): + self.assertIsNone( + service._safe_path_join(self.base, 'a/b/../../../etc/passwd') + ) + + def test_traversal_to_siblings_rejected(self): + # Reference the base directory's sibling by name; must still be rejected. + sibling = os.path.basename(os.path.dirname(str(self.base))) + '_evil' + self.assertIsNone(service._safe_path_join(self.base, f'../{sibling}/x')) + + def test_symlink_escape_rejected(self): + # A symlink inside ``base`` that points outside ``base`` must not be + # writable through ``_safe_path_join``. + outside = epath.Path(self.create_tempdir().full_path) + link_dir = self.base / 'link' + try: + os.symlink(str(outside), str(link_dir)) + except (OSError, NotImplementedError): + self.skipTest('Symlinks not supported on this platform') + self.assertIsNone(service._safe_path_join(self.base, 'link/pwn')) + + +class PathTraversalRegressionTest(absltest.TestCase): + """Regression tests for CVE-class path-traversal in the P2P service.""" + + def setUp(self): + super().setUp() + self.temp_dir = epath.Path(self.create_tempdir().full_path) + self.enter_context( + mock.patch.object(service, '_ThreadingTCPServer', autospec=True) + ) + server = service._ThreadingTCPServer.return_value + server.server_address = ('localhost', 12345) + self.enter_context( + mock.patch.object(service.multihost, 'process_index', return_value=0) + ) + self.enter_context( + mock.patch.object( + service.socket, + 'getaddrinfo', + return_value=[(service.socket.AF_INET, 0, 0, '', ('127.0.0.1', 0))], + ) + ) + self.node = service.P2PNode(directory=self.temp_dir) + + @mock.patch.object(service.protocol.TCPMessage, 'send_file', autospec=True) + def test_handle_download_blocks_traversal(self, mock_send_file): + """Server must reject ``..`` components regardless of form.""" + sock = mock.Mock() + for bad in [ + '../unsafe', + 'a/b/../../../etc/passwd', + '/etc/passwd', + '', + ]: + mock_send_file.reset_mock() + self.node.handle_download(sock, {'rel_path': bad}) + mock_send_file.assert_called_once_with(sock, epath.Path('__INVALID__')) + + @mock.patch.object(service.shutil, 'rmtree', autospec=True) + @mock.patch.object(service.shutil, 'move', autospec=True) + @mock.patch.object(service.time, 'time', autospec=True) + @mock.patch.object(service.protocol.TCPClient, 'request', autospec=True) + @mock.patch.object(service.protocol.TCPClient, 'download', autospec=True) + def test_fetch_shard_from_peer_rejects_malicious_manifest( + self, + mock_download, + mock_request, + unused_mock_time, + unused_mock_move, + unused_mock_rmtree, + ): + """Client must refuse to download peer-supplied traversal paths. + + A malicious peer returns a manifest whose ``rel_path`` tries to escape the + staging directory (e.g., writing a ``.pth`` file into site-packages). + ``fetch_shard_from_peer`` must abort before any ``download`` call. + """ + mock_request.return_value = [ + {'rel_path': '../../evil.pth', 'size': 10}, + ] + self.assertFalse(self.node.fetch_shard_from_peer('peer', 123, 1, 10)) + mock_download.assert_not_called() + + @mock.patch.object(service.shutil, 'rmtree', autospec=True) + @mock.patch.object(service.shutil, 'move', autospec=True) + @mock.patch.object(service.time, 'time', autospec=True) + @mock.patch.object(service.protocol.TCPClient, 'request', autospec=True) + @mock.patch.object(service.protocol.TCPClient, 'download', autospec=True) + def test_fetch_shard_from_peer_rejects_absolute_path( + self, + mock_download, + mock_request, + unused_mock_time, + unused_mock_move, + unused_mock_rmtree, + ): + mock_request.return_value = [ + {'rel_path': '/etc/passwd', 'size': 10}, + ] + self.assertFalse(self.node.fetch_shard_from_peer('peer', 123, 1, 10)) + mock_download.assert_not_called() + + class NodeHandlerTest(absltest.TestCase): def setUp(self):