Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pyisolate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,17 @@ def restore(*args, **kwargs): # type: ignore[no-redef]
from .editor import PolicyEditor, check_fs, check_tcp, parse_policy # noqa: F401
from .errors import (
CPUExceeded,
ChildWorkExceeded,
MemoryExceeded,
NetworkExceeded,
OpenFilesExceeded,
OutputExceeded,
PolicyAuthError,
PolicyError,
SandboxError,
TenantQuotaExceeded,
TimeoutError,
WallTimeExceeded,
)
from .logging import setup_structured_logging # noqa: F401

Expand Down Expand Up @@ -74,6 +80,12 @@ def migrate(*args, **kwargs): # type: ignore[no-redef]
"TimeoutError",
"MemoryExceeded",
"CPUExceeded",
"WallTimeExceeded",
"OpenFilesExceeded",
"NetworkExceeded",
"OutputExceeded",
"ChildWorkExceeded",
"TenantQuotaExceeded",
"sandbox",
"Pipeline",
"RestrictedExec",
Expand Down
24 changes: 24 additions & 0 deletions pyisolate/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,29 @@ class CPUExceeded(SandboxError):
"""Raised when a sandbox exceeds its CPU quota."""


class WallTimeExceeded(SandboxError):
"""Raised when a sandbox exceeds its wall-clock quota."""


class OpenFilesExceeded(SandboxError):
"""Raised when a sandbox exceeds its open-files quota."""


class NetworkExceeded(SandboxError):
"""Raised when a sandbox exceeds its network operations quota."""


class OutputExceeded(SandboxError):
"""Raised when a sandbox exceeds its output quota."""


class ChildWorkExceeded(SandboxError):
"""Raised when a sandbox exceeds its concurrent child-work quota."""


class TenantQuotaExceeded(SandboxError):
"""Raised when a tenant exceeds sustained quota."""


class OwnershipError(SandboxError):
"""Raised when a moved value is accessed."""
142 changes: 139 additions & 3 deletions pyisolate/runtime/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
import signal
import socket
import subprocess
import sys
import threading
import time
import tracemalloc
import types
import weakref
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Iterable, Optional
Expand All @@ -41,6 +43,7 @@

_ORIG_OPEN = builtins.open
_ORIG_SOCKET_CONNECT = socket.socket.connect
_ORIG_THREAD_START = threading.Thread.start


def _blocked_open(file, *args, **kwargs):
Expand All @@ -63,7 +66,19 @@ def _blocked_open(file, *args, **kwargs):
elif getattr(_thread_local, "active", False):
raise errors.PolicyError("file access blocked")

return _ORIG_OPEN(file, *args, **kwargs)
sandbox = getattr(_thread_local, "sandbox", None)
if sandbox is not None:
sandbox._check_open_files_quota()
opened = _ORIG_OPEN(file, *args, **kwargs)
if sandbox is None:
return opened
sandbox._open_files += 1

def _release():
sandbox._open_files = max(0, sandbox._open_files - 1)

weakref.finalize(opened, _release)
return opened


def _guarded_connect(self_socket: socket.socket, address: Iterable[str]):
Expand All @@ -81,6 +96,14 @@ def _guarded_connect(self_socket: socket.socket, address: Iterable[str]):
raise errors.PolicyError(f"connect blocked: {host}:{port}")
else:
raise errors.PolicyError(f"connect blocked: {host}:{port}")
sandbox = getattr(_thread_local, "sandbox", None)
if sandbox is not None:
sandbox._network_ops += 1
if (
sandbox.network_ops_max is not None
and sandbox._network_ops > sandbox.network_ops_max
):
raise errors.NetworkExceeded()
return _ORIG_SOCKET_CONNECT(self_socket, address)


Expand All @@ -98,6 +121,28 @@ def _guarded_urandom(n: int) -> bytes:
return cap.bytes(n)


def _guarded_thread_start(self_thread: threading.Thread, *args, **kwargs):
sandbox = getattr(_thread_local, "sandbox", None)
if sandbox is None:
return _ORIG_THREAD_START(self_thread, *args, **kwargs)
sandbox._check_child_work_quota()
original_run = self_thread.run

def _run_with_accounting(*r_args, **r_kwargs):
try:
return original_run(*r_args, **r_kwargs)
finally:
sandbox._child_work = max(0, sandbox._child_work - 1)

self_thread.run = _run_with_accounting # type: ignore[assignment]
sandbox._child_work += 1
try:
return _ORIG_THREAD_START(self_thread, *args, **kwargs)
except Exception:
sandbox._child_work = max(0, sandbox._child_work - 1)
raise


def _wrap_module(name: str, module):
base = name.split(".")[0]
if base == "time":
Expand Down Expand Up @@ -270,6 +315,11 @@ def __init__(
policy=None,
cpu_ms: Optional[int] = None,
mem_bytes: Optional[int] = None,
wall_time_ms: Optional[int] = None,
open_files_max: Optional[int] = None,
network_ops_max: Optional[int] = None,
output_bytes_max: Optional[int] = None,
child_work_max: Optional[int] = None,
allowed_imports: Optional[list[str]] = None,
on_violation: Optional[Callable[[str, Exception], None]] = None,
tracer: Optional["Tracer"] = None,
Expand All @@ -285,6 +335,11 @@ def __init__(
self.policy = policy
self.cpu_quota_ms = cpu_ms
self.mem_quota_bytes = mem_bytes
self.wall_time_ms = wall_time_ms
self.open_files_max = open_files_max
self.network_ops_max = network_ops_max
self.output_bytes_max = output_bytes_max
self.child_work_max = child_work_max
self.allowed_imports = self._merge_allowed_imports(policy, allowed_imports)
self._cpu_time = 0.0
self._mem_peak = 0
Expand All @@ -303,6 +358,11 @@ def __init__(
self._syscall_log: list[str] = []
self._capabilities = dict(capabilities or {})
self._quarantine_reason: str | None = None
self.termination_reason: str | None = None
self._open_files = 0
self._network_ops = 0
self._output_bytes = 0
self._child_work = 0
self._next_attach_msg_id = 1
self._seen_attach_msg_ids: set[int] = set()

Expand All @@ -318,8 +378,45 @@ def snapshot(self) -> dict:
else None,
"numa_node": self.numa_node,
"capabilities": sorted(self._capabilities),
"wall_time_ms": self.wall_time_ms,
"open_files_max": self.open_files_max,
"network_ops_max": self.network_ops_max,
"output_bytes_max": self.output_bytes_max,
"child_work_max": self.child_work_max,
}

@staticmethod
def _estimate_output_size(item: Any) -> int:
if isinstance(item, bytes):
return len(item)
if isinstance(item, str):
return len(item.encode("utf-8"))
return len(repr(item).encode("utf-8"))

def _post(self, item: Any) -> None:
self._output_bytes += self._estimate_output_size(item)
if self.output_bytes_max is not None and self._output_bytes > self.output_bytes_max:
raise errors.OutputExceeded()
self._outbox.put(item)

def _check_open_files_quota(self) -> None:
if self.open_files_max is not None and self._open_files >= self.open_files_max:
raise errors.OpenFilesExceeded()

def _check_child_work_quota(self) -> None:
if self.child_work_max is not None and self._child_work >= self.child_work_max:
raise errors.ChildWorkExceeded()

def _trace_guard(self, frame, event, arg):
if self.wall_time_ms is None:
return self._trace_guard
if self._start_time is None:
return self._trace_guard
elapsed_ms = (time.monotonic() - self._start_time) * 1000
if elapsed_ms > self.wall_time_ms:
raise errors.WallTimeExceeded()
return self._trace_guard

def enable_tracing(self) -> None:
"""Start recording guest operations."""
self._trace_enabled = True
Expand Down Expand Up @@ -416,6 +513,11 @@ def reset(
policy=None,
cpu_ms: Optional[int] = None,
mem_bytes: Optional[int] = None,
wall_time_ms: Optional[int] = None,
open_files_max: Optional[int] = None,
network_ops_max: Optional[int] = None,
output_bytes_max: Optional[int] = None,
child_work_max: Optional[int] = None,
allowed_imports: Optional[list[str]] = None,
numa_node: Optional[int] = None,
cgroup_path=None,
Expand All @@ -427,6 +529,11 @@ def reset(
self.policy = policy
self.cpu_quota_ms = cpu_ms
self.mem_quota_bytes = mem_bytes
self.wall_time_ms = wall_time_ms
self.open_files_max = open_files_max
self.network_ops_max = network_ops_max
self.output_bytes_max = output_bytes_max
self.child_work_max = child_work_max
self.numa_node = numa_node
self._bound_numa_node = None
self.allowed_imports = self._merge_allowed_imports(policy, allowed_imports)
Expand All @@ -441,6 +548,11 @@ def reset(
self._start_time = None
self._cgroup_path = cgroup_path
self._capabilities = dict(capabilities or {})
self.termination_reason = None
self._open_files = 0
self._network_ops = 0
self._output_bytes = 0
self._child_work = 0
# Request the sandbox thread to (re)attach itself to the new cgroup.
# The attachment must happen from the sandbox thread's context.
msg_id = self._next_attach_msg_id
Expand Down Expand Up @@ -485,14 +597,16 @@ def run(self) -> None:
self._cpu_time = 0.0
self._start_time = None

local_vars = {"post": self._outbox.put, "caps": self._capabilities}
local_vars = {"post": self._post, "caps": self._capabilities}

if self.numa_node is not None:
bind_current_thread(self.numa_node)
self._bound_numa_node = self.numa_node

while True:
payload = self._inbox.get()
if payload is _STOP:
break
if isinstance(payload, StopRequest):
break
if isinstance(payload, AttachCgroupRequest):
Expand All @@ -513,6 +627,8 @@ def run(self) -> None:
if self.numa_node is not None:
bind_current_thread(self.numa_node)
self._bound_numa_node = self.numa_node
if isinstance(payload, str):
payload = ExecRequest(source=payload)

allowed_tcp = None
allowed_fs = None
Expand All @@ -537,6 +653,8 @@ def run(self) -> None:
)
_thread_local.clock_capability = self._capabilities.get("clock")
_thread_local.random_capability = self._capabilities.get("random")
_thread_local.sandbox = self
threading.Thread.start = _guarded_thread_start

builtins_dict = _SAFE_BUILTINS.copy()
builtins_dict["open"] = _blocked_open
Expand All @@ -546,9 +664,11 @@ def run(self) -> None:
self._ops += 1
op_start = time.monotonic()
with self._tracer.start_span(f"sandbox:{self.name}"):
sys_trace_before = sys.gettrace()
try:
start_cpu = time.thread_time()
self._start_time = time.monotonic()
sys.settrace(self._trace_guard)
if isinstance(payload, CallRequest):
importer = builtins_dict["__import__"]
try:
Expand All @@ -563,7 +683,7 @@ def run(self) -> None:
res = object.__getattribute__(mod, func_name)(
*payload.args, **payload.kwargs
)
self._outbox.put(res)
self._post(res)
elif isinstance(payload, ExecRequest):
exec(payload.source, local_vars, local_vars)
else:
Expand All @@ -588,10 +708,25 @@ def run(self) -> None:
break
self._errors += 1
self._start_time = None
if isinstance(exc, errors.WallTimeExceeded):
self.termination_reason = "wall_time_exceeded"
elif isinstance(exc, errors.OpenFilesExceeded):
self.termination_reason = "open_files_exceeded"
elif isinstance(exc, errors.NetworkExceeded):
self.termination_reason = "network_exceeded"
elif isinstance(exc, errors.OutputExceeded):
self.termination_reason = "output_exceeded"
elif isinstance(exc, errors.ChildWorkExceeded):
self.termination_reason = "child_work_exceeded"
elif isinstance(exc, errors.CPUExceeded):
self.termination_reason = "cpu_exceeded"
elif isinstance(exc, errors.MemoryExceeded):
self.termination_reason = "memory_exceeded"
if self._on_violation and isinstance(exc, errors.PolicyError):
self._on_violation(self.name, exc)
self._outbox.put(exc)
finally:
sys.settrace(sys_trace_before)
self._start_time = None
duration = (time.monotonic() - op_start) * 1000
self._latency_sum += duration
Expand All @@ -607,5 +742,6 @@ def run(self) -> None:
self._latency["inf"] += 1
_thread_local.active = False
finally:
threading.Thread.start = _ORIG_THREAD_START
if prev_handler is not None:
signal.signal(signal.SIGXCPU, prev_handler)
Loading
Loading