From e50498aa901bda681d68612218c455c9c707bd7f Mon Sep 17 00:00:00 2001 From: Sean Evans Date: Tue, 21 Apr 2026 08:24:18 -0400 Subject: [PATCH] Localize sandbox thread quota enforcement --- pyisolate/runtime/thread.py | 49 +++++++++++++++------------ tests/test_thread_extra.py | 66 +++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 21 deletions(-) diff --git a/pyisolate/runtime/thread.py b/pyisolate/runtime/thread.py index 404eaa3..8c87c61 100644 --- a/pyisolate/runtime/thread.py +++ b/pyisolate/runtime/thread.py @@ -121,26 +121,27 @@ def _guarded_urandom(n: int) -> bytes: return cap.bytes(n) -def _guarded_thread_start(self_thread: threading.Thread, *args, **kwargs): - sandbox = getattr(_thread_local, "sandbox", None) - if sandbox is None: - return _ORIG_THREAD_START(self_thread, *args, **kwargs) - sandbox._check_child_work_quota() - original_run = self_thread.run - - def _run_with_accounting(*r_args, **r_kwargs): - try: - return original_run(*r_args, **r_kwargs) - finally: - sandbox._child_work = max(0, sandbox._child_work - 1) +def _make_sandbox_thread_class(sandbox: "SandboxThread"): + class SandboxedThread(threading.Thread): + def start(self, *args, **kwargs): + sandbox._check_child_work_quota() + original_run = self.run + + def _run_with_accounting(*r_args, **r_kwargs): + try: + return original_run(*r_args, **r_kwargs) + finally: + sandbox._child_work = max(0, sandbox._child_work - 1) + + self.run = _run_with_accounting # type: ignore[assignment] + sandbox._child_work += 1 + try: + return _ORIG_THREAD_START(self, *args, **kwargs) + except Exception: + sandbox._child_work = max(0, sandbox._child_work - 1) + raise - self_thread.run = _run_with_accounting # type: ignore[assignment] - sandbox._child_work += 1 - try: - return _ORIG_THREAD_START(self_thread, *args, **kwargs) - except Exception: - sandbox._child_work = max(0, sandbox._child_work - 1) - raise + return SandboxedThread def _wrap_module(name: str, module): @@ -208,6 +209,14 @@ class GuardedSocket(socket.socket): mod.__dict__.update({k: getattr(random, k) for k in dir(random)}) mod.randbytes = _guarded_urandom return mod + if base == "threading": + sandbox = getattr(_thread_local, "sandbox", None) + if sandbox is None: + return module + mod = types.ModuleType("threading", module.__doc__) + mod.__dict__.update({k: getattr(threading, k) for k in dir(threading)}) + mod.Thread = _make_sandbox_thread_class(sandbox) + return mod if base == "pathlib": mod = types.ModuleType("pathlib", module.__doc__) mod.__dict__.update({k: getattr(module, k) for k in dir(module)}) @@ -654,7 +663,6 @@ def run(self) -> None: _thread_local.clock_capability = self._capabilities.get("clock") _thread_local.random_capability = self._capabilities.get("random") _thread_local.sandbox = self - threading.Thread.start = _guarded_thread_start builtins_dict = _SAFE_BUILTINS.copy() builtins_dict["open"] = _blocked_open @@ -742,6 +750,5 @@ def run(self) -> None: self._latency["inf"] += 1 _thread_local.active = False finally: - threading.Thread.start = _ORIG_THREAD_START if prev_handler is not None: signal.signal(signal.SIGXCPU, prev_handler) diff --git a/tests/test_thread_extra.py b/tests/test_thread_extra.py index 1bc945a..e999d6b 100644 --- a/tests/test_thread_extra.py +++ b/tests/test_thread_extra.py @@ -97,3 +97,69 @@ def fake_delete(path): # One call comes from thread startup attach, one from unique control message. assert calls["attach"] == 2 assert calls["delete"] == 0 + + +def test_sandbox_threading_patch_is_local_and_does_not_touch_global_start(): + iso.shutdown() + + original_start = threading.Thread.start + started_outside = {"count": 0} + failures = [] + + def outside_worker() -> None: + started_outside["count"] += 1 + + sb = SandboxThread(name="local-threading", child_work_max=1) + sb.start() + try: + sb.exec( + """ +import threading + +def worker(): + pass + +t = threading.Thread(target=worker) +t.start() +t.join() +post('sandbox-ok') +""" + ) + assert sb.recv(timeout=1) == "sandbox-ok" + + # A host thread should still use the original global start implementation. + host_thread = threading.Thread(target=outside_worker) + host_thread.start() + host_thread.join(timeout=1) + assert started_outside["count"] == 1 + assert threading.Thread.start is original_start + + # While sandbox work is running, repeatedly start host threads and + # ensure no global behavior bleed happens. + sb.exec( + """ +import threading +import time + +def worker(): + time.sleep(0.03) + +t = threading.Thread(target=worker) +t.start() +t.join() +post('sandbox-loop') +""" + ) + + for _ in range(10): + t = threading.Thread(target=outside_worker) + t.start() + t.join(timeout=1) + if threading.Thread.start is not original_start: + failures.append("threading.Thread.start changed globally") + + assert sb.recv(timeout=1) == "sandbox-loop" + assert not failures + assert threading.Thread.start is original_start + finally: + sb.stop()