Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 28 additions & 21 deletions pyisolate/runtime/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,26 +121,27 @@ def _guarded_urandom(n: int) -> bytes:
return cap.bytes(n)


def _guarded_thread_start(self_thread: threading.Thread, *args, **kwargs):
sandbox = getattr(_thread_local, "sandbox", None)
if sandbox is None:
return _ORIG_THREAD_START(self_thread, *args, **kwargs)
sandbox._check_child_work_quota()
original_run = self_thread.run

def _run_with_accounting(*r_args, **r_kwargs):
try:
return original_run(*r_args, **r_kwargs)
finally:
sandbox._child_work = max(0, sandbox._child_work - 1)
def _make_sandbox_thread_class(sandbox: "SandboxThread"):
class SandboxedThread(threading.Thread):
def start(self, *args, **kwargs):
sandbox._check_child_work_quota()
original_run = self.run

def _run_with_accounting(*r_args, **r_kwargs):
try:
return original_run(*r_args, **r_kwargs)
finally:
sandbox._child_work = max(0, sandbox._child_work - 1)

self.run = _run_with_accounting # type: ignore[assignment]
sandbox._child_work += 1
try:
return _ORIG_THREAD_START(self, *args, **kwargs)
except Exception:
sandbox._child_work = max(0, sandbox._child_work - 1)
raise

self_thread.run = _run_with_accounting # type: ignore[assignment]
sandbox._child_work += 1
try:
return _ORIG_THREAD_START(self_thread, *args, **kwargs)
except Exception:
sandbox._child_work = max(0, sandbox._child_work - 1)
raise
return SandboxedThread


def _wrap_module(name: str, module):
Expand Down Expand Up @@ -208,6 +209,14 @@ class GuardedSocket(socket.socket):
mod.__dict__.update({k: getattr(random, k) for k in dir(random)})
mod.randbytes = _guarded_urandom
return mod
if base == "threading":
sandbox = getattr(_thread_local, "sandbox", None)
if sandbox is None:
return module
mod = types.ModuleType("threading", module.__doc__)
mod.__dict__.update({k: getattr(threading, k) for k in dir(threading)})
mod.Thread = _make_sandbox_thread_class(sandbox)
return mod
if base == "pathlib":
mod = types.ModuleType("pathlib", module.__doc__)
mod.__dict__.update({k: getattr(module, k) for k in dir(module)})
Expand Down Expand Up @@ -654,7 +663,6 @@ def run(self) -> None:
_thread_local.clock_capability = self._capabilities.get("clock")
_thread_local.random_capability = self._capabilities.get("random")
_thread_local.sandbox = self
threading.Thread.start = _guarded_thread_start

builtins_dict = _SAFE_BUILTINS.copy()
builtins_dict["open"] = _blocked_open
Expand Down Expand Up @@ -742,6 +750,5 @@ def run(self) -> None:
self._latency["inf"] += 1
_thread_local.active = False
finally:
threading.Thread.start = _ORIG_THREAD_START
if prev_handler is not None:
signal.signal(signal.SIGXCPU, prev_handler)
66 changes: 66 additions & 0 deletions tests/test_thread_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,69 @@ def fake_delete(path):
# One call comes from thread startup attach, one from unique control message.
assert calls["attach"] == 2
assert calls["delete"] == 0


def test_sandbox_threading_patch_is_local_and_does_not_touch_global_start():
iso.shutdown()

original_start = threading.Thread.start
started_outside = {"count": 0}
failures = []

def outside_worker() -> None:
started_outside["count"] += 1

sb = SandboxThread(name="local-threading", child_work_max=1)
sb.start()
try:
sb.exec(
"""
import threading

def worker():
pass

t = threading.Thread(target=worker)
t.start()
t.join()
post('sandbox-ok')
"""
)
assert sb.recv(timeout=1) == "sandbox-ok"

# A host thread should still use the original global start implementation.
host_thread = threading.Thread(target=outside_worker)
host_thread.start()
host_thread.join(timeout=1)
assert started_outside["count"] == 1
assert threading.Thread.start is original_start

# While sandbox work is running, repeatedly start host threads and
# ensure no global behavior bleed happens.
sb.exec(
"""
import threading
import time

def worker():
time.sleep(0.03)

t = threading.Thread(target=worker)
t.start()
t.join()
post('sandbox-loop')
"""
)

for _ in range(10):
t = threading.Thread(target=outside_worker)
t.start()
t.join(timeout=1)
if threading.Thread.start is not original_start:
failures.append("threading.Thread.start changed globally")

assert sb.recv(timeout=1) == "sandbox-loop"
assert not failures
assert threading.Thread.start is original_start
finally:
sb.stop()
Loading