Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 52 additions & 3 deletions checkpoint/orbax/checkpoint/experimental/emergency/p2p/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import concurrent.futures
import functools
import os
import shutil
import socket
import socketserver
Expand All @@ -31,6 +32,43 @@
from orbax.checkpoint.experimental.emergency.p2p import utils


def _safe_path_join(base: epath.Path, rel_path: str) -> epath.Path | None:
  """Joins ``rel_path`` onto ``base`` iff the result stays within ``base``.

  Rejects empty or non-string values, absolute paths (including Windows
  drive-prefixed paths), values for which path resolution raises (for
  example strings with an embedded NUL byte), and any path whose resolved
  form escapes ``base`` via ``..`` components or symlinks. Returns ``None``
  for unsafe inputs so callers can fail closed instead of having to handle
  exceptions triggered by peer-controlled data.

  Args:
    base: The base directory that the resulting path must be contained within.
    rel_path: A peer-supplied relative path. Treated as untrusted.

  Returns:
    ``base / rel_path`` on success, or ``None`` if ``rel_path`` is unsafe.
  """
  # The value is peer-controlled, so it may not even be a string (e.g. a
  # number or ``None`` in a manifest). ``os.path`` helpers raise ``TypeError``
  # on such input, so reject it up front rather than letting that propagate.
  if not isinstance(rel_path, str) or not rel_path:
    return None
  # Reject absolute paths. ``os.path.isabs`` covers POSIX roots and (on
  # Windows) drive-prefixed paths; ``splitdrive`` catches e.g. ``C:foo`` which
  # is technically relative but is not anchored at ``base``.
  if os.path.isabs(rel_path) or os.path.splitdrive(rel_path)[0]:
    return None
  try:
    # ``realpath`` collapses ``..`` components and follows symlinks, so the
    # containment comparison below is made on the real filesystem location.
    base_real = os.path.realpath(str(base))
    candidate_real = os.path.realpath(os.path.join(base_real, rel_path))
    common = os.path.commonpath([base_real, candidate_real])
  except (OSError, ValueError):
    # ``OSError``: resolution failed. ``ValueError``: embedded NUL byte, or
    # the two paths are incomparable (e.g. different drives on Windows).
    return None
  if common != base_real:
    return None
  return base / rel_path


class _ThreadingTCPServer(socketserver.ThreadingTCPServer):
"""A ThreadingTCPServer that holds a reference to a P2PNode."""

Expand Down Expand Up @@ -206,12 +244,12 @@ def handle_download(self, sock, payload: dict[str, Any]):
"""
rel_path_str = payload.get('rel_path')

if not rel_path_str or '..' in rel_path_str or rel_path_str.startswith('/'):
full_path = _safe_path_join(self.directory, rel_path_str) if rel_path_str else None
if full_path is None:
logging.error('Blocked unsafe P2P path request: %s', rel_path_str)
protocol.TCPMessage.send_file(sock, epath.Path('__INVALID__'))
return

full_path = self.directory / rel_path_str
if full_path.exists() and full_path.is_file():
protocol.TCPMessage.send_file(sock, full_path)
else:
Expand Down Expand Up @@ -268,7 +306,18 @@ def fetch_shard_from_peer(
futures = []
for f_meta in manifest:
rel_path_str = f_meta['rel_path']
dest_path = stage_dir / rel_path_str
dest_path = _safe_path_join(stage_dir, rel_path_str)
if dest_path is None:
logging.error(
'Rejecting unsafe manifest entry from peer %s:%d for'
' step=%d, process_index=%d: %r',
ip,
port,
step,
stored_process_index,
rel_path_str,
)
return False
futures.append(
exc.submit(
protocol.TCPClient.download, ip, port, rel_path_str, dest_path
Expand Down
130 changes: 130 additions & 0 deletions checkpoint/orbax/checkpoint/experimental/emergency/p2p/service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Unit tests for P2PNode service."""

import functools
import os
import threading
from unittest import mock

Expand All @@ -23,6 +24,135 @@
from orbax.checkpoint.experimental.emergency.p2p import service


class SafePathJoinTest(absltest.TestCase):
  """Tests for the ``_safe_path_join`` path-containment helper."""

  def setUp(self):
    super().setUp()
    # Fresh temporary directory used as the containment root in every test.
    self.base = epath.Path(self.create_tempdir().full_path)

  def test_valid_relative_path(self):
    result = service._safe_path_join(self.base, '1/file1')
    self.assertIsNotNone(result)
    self.assertEqual(str(result), str(self.base / '1/file1'))

  def test_valid_nested_path(self):
    result = service._safe_path_join(self.base, '1/subdir/file2')
    self.assertIsNotNone(result)
    # The accepted path must be the plain join, not merely "something".
    self.assertEqual(str(result), str(self.base / '1/subdir/file2'))

  def test_empty_rel_path_rejected(self):
    self.assertIsNone(service._safe_path_join(self.base, ''))

  def test_absolute_posix_path_rejected(self):
    self.assertIsNone(service._safe_path_join(self.base, '/etc/passwd'))

  def test_parent_traversal_rejected(self):
    self.assertIsNone(service._safe_path_join(self.base, '../secret'))

  def test_deep_parent_traversal_rejected(self):
    self.assertIsNone(
        service._safe_path_join(self.base, 'a/b/../../../etc/passwd')
    )

  def test_traversal_to_siblings_rejected(self):
    # Build a sibling whose name extends the base directory's own name
    # (``<base>_evil``). Its full path shares a string prefix with ``base``,
    # which is the case a naive prefix comparison would wrongly accept.
    sibling = os.path.basename(str(self.base)) + '_evil'
    self.assertIsNone(service._safe_path_join(self.base, f'../{sibling}/x'))

  def test_symlink_escape_rejected(self):
    # A symlink inside ``base`` that points outside ``base`` must not be
    # writable through ``_safe_path_join``.
    outside = epath.Path(self.create_tempdir().full_path)
    link_dir = self.base / 'link'
    try:
      os.symlink(str(outside), str(link_dir))
    except (OSError, NotImplementedError):
      self.skipTest('Symlinks not supported on this platform')
    self.assertIsNone(service._safe_path_join(self.base, 'link/pwn'))


class PathTraversalRegressionTest(absltest.TestCase):
  """Regression tests for CVE-class path-traversal in the P2P service."""

  def setUp(self):
    super().setUp()
    self.temp_dir = epath.Path(self.create_tempdir().full_path)
    # Swap the TCP server class for an autospec mock and give the mocked
    # instance a fixed address before the node is constructed.
    self.enter_context(
        mock.patch.object(service, '_ThreadingTCPServer', autospec=True)
    )
    fake_server = service._ThreadingTCPServer.return_value
    fake_server.server_address = ('localhost', 12345)
    # Pin the process index and the address lookup to constant values.
    self.enter_context(
        mock.patch.object(service.multihost, 'process_index', return_value=0)
    )
    fake_addrinfo = [(service.socket.AF_INET, 0, 0, '', ('127.0.0.1', 0))]
    self.enter_context(
        mock.patch.object(
            service.socket, 'getaddrinfo', return_value=fake_addrinfo
        )
    )
    self.node = service.P2PNode(directory=self.temp_dir)

  def _assert_fetch_rejected(self, mock_request, mock_download, manifest):
    """Asserts that fetching ``manifest`` fails without any download call."""
    mock_request.return_value = manifest
    # NOTE(review): positional args presumably are ip, port, step and
    # stored process index -- confirm against ``fetch_shard_from_peer``.
    fetched = self.node.fetch_shard_from_peer('peer', 123, 1, 10)
    self.assertFalse(fetched)
    mock_download.assert_not_called()

  @mock.patch.object(service.protocol.TCPMessage, 'send_file', autospec=True)
  def test_handle_download_blocks_traversal(self, mock_send_file):
    """Every unsafe ``rel_path`` must be answered with the invalid marker."""
    client_sock = mock.Mock()
    invalid_marker = epath.Path('__INVALID__')
    unsafe_requests = (
        '../unsafe',
        'a/b/../../../etc/passwd',
        '/etc/passwd',
        '',
    )
    for unsafe_path in unsafe_requests:
      # Start each iteration with a clean call record so that
      # ``assert_called_once_with`` checks only the current request.
      mock_send_file.reset_mock()
      self.node.handle_download(client_sock, {'rel_path': unsafe_path})
      mock_send_file.assert_called_once_with(client_sock, invalid_marker)

  @mock.patch.object(service.shutil, 'rmtree', autospec=True)
  @mock.patch.object(service.shutil, 'move', autospec=True)
  @mock.patch.object(service.time, 'time', autospec=True)
  @mock.patch.object(service.protocol.TCPClient, 'request', autospec=True)
  @mock.patch.object(service.protocol.TCPClient, 'download', autospec=True)
  def test_fetch_shard_from_peer_rejects_malicious_manifest(
      self,
      mock_download,
      mock_request,
      unused_mock_time,
      unused_mock_move,
      unused_mock_rmtree,
  ):
    """Client must refuse to download peer-supplied traversal paths.

    The mocked peer answers with a manifest whose ``rel_path`` climbs out of
    the staging directory (e.g., to drop a ``.pth`` file into site-packages).
    ``fetch_shard_from_peer`` has to give up before issuing any ``download``.
    """
    self._assert_fetch_rejected(
        mock_request,
        mock_download,
        [{'rel_path': '../../evil.pth', 'size': 10}],
    )

  @mock.patch.object(service.shutil, 'rmtree', autospec=True)
  @mock.patch.object(service.shutil, 'move', autospec=True)
  @mock.patch.object(service.time, 'time', autospec=True)
  @mock.patch.object(service.protocol.TCPClient, 'request', autospec=True)
  @mock.patch.object(service.protocol.TCPClient, 'download', autospec=True)
  def test_fetch_shard_from_peer_rejects_absolute_path(
      self,
      mock_download,
      mock_request,
      unused_mock_time,
      unused_mock_move,
      unused_mock_rmtree,
  ):
    """Client must refuse a manifest entry whose ``rel_path`` is absolute."""
    self._assert_fetch_rejected(
        mock_request,
        mock_download,
        [{'rel_path': '/etc/passwd', 'size': 10}],
    )


class NodeHandlerTest(absltest.TestCase):

def setUp(self):
Expand Down