class WallTimeExceeded(SandboxError):
    """Signal that a sandbox ran past its wall-clock quota."""


class OpenFilesExceeded(SandboxError):
    """Signal that a sandbox went over its open-files quota."""


class NetworkExceeded(SandboxError):
    """Signal that a sandbox went over its network operations quota."""


class OutputExceeded(SandboxError):
    """Signal that a sandbox produced more output than its quota allows."""


class ChildWorkExceeded(SandboxError):
    """Signal that a sandbox went over its concurrent child-work quota."""


class TenantQuotaExceeded(SandboxError):
    """Signal that a tenant went over its sustained quota."""
def _guarded_thread_start(self_thread: threading.Thread, *args, **kwargs):
    """Start *self_thread*, charging it to the calling sandbox's child-work quota.

    Written as a stand-in for ``threading.Thread.start``.  When the calling
    thread has no ``_thread_local.sandbox`` the call is forwarded untouched.
    Otherwise the sandbox's quota check runs first, ``run`` is wrapped on the
    instance so the counter drops again when the child finishes, and the
    counter is bumped before the real start.

    Args:
        self_thread: The thread object being started.
        *args, **kwargs: Forwarded verbatim to the original ``start``.

    Returns:
        Whatever the original ``start`` returns.

    Raises:
        Whatever ``sandbox._check_child_work_quota()`` raises when the quota
        is exhausted, and whatever the original ``start`` raises.  On a
        failed start the counter and the ``run`` attribute are restored.
    """
    sandbox = getattr(_thread_local, "sandbox", None)
    # ``ident`` is only non-None once a thread has been started; let the real
    # start() reject a repeat call without touching ``run`` or the counter.
    if sandbox is None or self_thread.ident is not None:
        return _ORIG_THREAD_START(self_thread, *args, **kwargs)
    sandbox._check_child_work_quota()

    # Remember whether ``run`` was already shadowed on the instance so a
    # failed start can put things back exactly as they were.
    had_instance_run = "run" in vars(self_thread)
    original_run = self_thread.run

    def _run_with_accounting(*r_args, **r_kwargs):
        try:
            return original_run(*r_args, **r_kwargs)
        finally:
            # Executed by the started thread when its work ends.
            # NOTE(review): this read-modify-write is not synchronised with
            # the updates made by the thread that called start().
            sandbox._child_work = max(0, sandbox._child_work - 1)

    self_thread.run = _run_with_accounting  # type: ignore[assignment]
    sandbox._child_work += 1
    try:
        return _ORIG_THREAD_START(self_thread, *args, **kwargs)
    except Exception:
        # The child never ran: hand the slot back and drop the wrapper so a
        # retry is not wrapped (and later decremented) a second time.
        sandbox._child_work = max(0, sandbox._child_work - 1)
        if had_instance_run:
            self_thread.run = original_run  # type: ignore[assignment]
        else:
            del self_thread.run
        raise
@staticmethod
def _estimate_output_size(item: Any) -> int:
    """Return how many bytes *item* is charged against the output quota.

    Bytes-like values are charged their payload size, ``str`` values their
    UTF-8 length, and anything else the UTF-8 length of its ``repr``.
    """
    if isinstance(item, memoryview):
        # len() of a memoryview counts elements, not bytes.
        return item.nbytes
    if isinstance(item, (bytes, bytearray)):
        return len(item)
    # "surrogatepass": a str may legally hold lone surrogates, which a strict
    # encode would reject with UnicodeEncodeError instead of sizing them.
    if isinstance(item, str):
        return len(item.encode("utf-8", "surrogatepass"))
    return len(repr(item).encode("utf-8", "surrogatepass"))


def _post(self, item: Any) -> None:
    """Charge *item* to the output quota, then put it on ``self._outbox``.

    Raises:
        errors.OutputExceeded: when ``output_bytes_max`` is set and the
            running total goes above it.  The item is added to the total
            before the check, so it is counted even though it is not queued.
    """
    self._output_bytes += self._estimate_output_size(item)
    if self.output_bytes_max is not None and self._output_bytes > self.output_bytes_max:
        raise errors.OutputExceeded()
    self._outbox.put(item)


def _check_open_files_quota(self) -> None:
    """Raise ``errors.OpenFilesExceeded`` if no open-file slot is left.

    Does nothing when ``open_files_max`` is ``None``.
    """
    if self.open_files_max is not None and self._open_files >= self.open_files_max:
        raise errors.OpenFilesExceeded()


def _check_child_work_quota(self) -> None:
    """Raise ``errors.ChildWorkExceeded`` if no child-work slot is left.

    Does nothing when ``child_work_max`` is ``None``.
    """
    if self.child_work_max is not None and self._child_work >= self.child_work_max:
        raise errors.ChildWorkExceeded()


def _trace_guard(self, frame, event, arg):
    """Trace callback that enforces the wall-clock quota.

    Takes the ``(frame, event, arg)`` arguments of a ``sys.settrace``
    callback but ignores them.  Returns itself so it stays installed.

    Raises:
        errors.WallTimeExceeded: when ``wall_time_ms`` is set, a start time
            is recorded, and more than ``wall_time_ms`` has elapsed.
    """
    if self.wall_time_ms is None:
        return self._trace_guard
    if self._start_time is None:
        return self._trace_guard
    elapsed_ms = (time.monotonic() - self._start_time) * 1000
    if elapsed_ms > self.wall_time_ms:
        # NOTE(review): the sys.settrace docs say a trace function that
        # raises gets unset; traced code that catches this exception would
        # then keep running unmetered -- confirm this is acceptable.
        raise errors.WallTimeExceeded()
    return self._trace_guard
self.child_work_max = child_work_max self.numa_node = numa_node self._bound_numa_node = None self.allowed_imports = self._merge_allowed_imports(policy, allowed_imports) @@ -441,6 +548,11 @@ def reset( self._start_time = None self._cgroup_path = cgroup_path self._capabilities = dict(capabilities or {}) + self.termination_reason = None + self._open_files = 0 + self._network_ops = 0 + self._output_bytes = 0 + self._child_work = 0 # Request the sandbox thread to (re)attach itself to the new cgroup. # The attachment must happen from the sandbox thread's context. msg_id = self._next_attach_msg_id @@ -485,7 +597,7 @@ def run(self) -> None: self._cpu_time = 0.0 self._start_time = None - local_vars = {"post": self._outbox.put, "caps": self._capabilities} + local_vars = {"post": self._post, "caps": self._capabilities} if self.numa_node is not None: bind_current_thread(self.numa_node) @@ -493,6 +605,8 @@ def run(self) -> None: while True: payload = self._inbox.get() + if payload is _STOP: + break if isinstance(payload, StopRequest): break if isinstance(payload, AttachCgroupRequest): @@ -513,6 +627,8 @@ def run(self) -> None: if self.numa_node is not None: bind_current_thread(self.numa_node) self._bound_numa_node = self.numa_node + if isinstance(payload, str): + payload = ExecRequest(source=payload) allowed_tcp = None allowed_fs = None @@ -537,6 +653,8 @@ def run(self) -> None: ) _thread_local.clock_capability = self._capabilities.get("clock") _thread_local.random_capability = self._capabilities.get("random") + _thread_local.sandbox = self + threading.Thread.start = _guarded_thread_start builtins_dict = _SAFE_BUILTINS.copy() builtins_dict["open"] = _blocked_open @@ -546,9 +664,11 @@ def run(self) -> None: self._ops += 1 op_start = time.monotonic() with self._tracer.start_span(f"sandbox:{self.name}"): + sys_trace_before = sys.gettrace() try: start_cpu = time.thread_time() self._start_time = time.monotonic() + sys.settrace(self._trace_guard) if isinstance(payload, CallRequest): 
importer = builtins_dict["__import__"] try: @@ -563,7 +683,7 @@ def run(self) -> None: res = object.__getattribute__(mod, func_name)( *payload.args, **payload.kwargs ) - self._outbox.put(res) + self._post(res) elif isinstance(payload, ExecRequest): exec(payload.source, local_vars, local_vars) else: @@ -588,10 +708,25 @@ def run(self) -> None: break self._errors += 1 self._start_time = None + if isinstance(exc, errors.WallTimeExceeded): + self.termination_reason = "wall_time_exceeded" + elif isinstance(exc, errors.OpenFilesExceeded): + self.termination_reason = "open_files_exceeded" + elif isinstance(exc, errors.NetworkExceeded): + self.termination_reason = "network_exceeded" + elif isinstance(exc, errors.OutputExceeded): + self.termination_reason = "output_exceeded" + elif isinstance(exc, errors.ChildWorkExceeded): + self.termination_reason = "child_work_exceeded" + elif isinstance(exc, errors.CPUExceeded): + self.termination_reason = "cpu_exceeded" + elif isinstance(exc, errors.MemoryExceeded): + self.termination_reason = "memory_exceeded" if self._on_violation and isinstance(exc, errors.PolicyError): self._on_violation(self.name, exc) self._outbox.put(exc) finally: + sys.settrace(sys_trace_before) self._start_time = None duration = (time.monotonic() - op_start) * 1000 self._latency_sum += duration @@ -607,5 +742,6 @@ def run(self) -> None: self._latency["inf"] += 1 _thread_local.active = False finally: + threading.Thread.start = _ORIG_THREAD_START if prev_handler is not None: signal.signal(signal.SIGXCPU, prev_handler) diff --git a/pyisolate/supervisor.py b/pyisolate/supervisor.py index bae83a8..2c6a395 100644 --- a/pyisolate/supervisor.py +++ b/pyisolate/supervisor.py @@ -9,6 +9,7 @@ import importlib import logging +import os import re import threading from pathlib import Path @@ -16,7 +17,7 @@ from . 
import cgroup, recovery from .capabilities import ROOT, RootCapability -from .errors import PolicyAuthError +from .errors import PolicyAuthError, TenantQuotaExceeded from .observability.alerts import AlertManager from .observability.trace import Tracer from .runtime.protocol import CapabilityHandle, ControlRequest @@ -71,6 +72,11 @@ def reset(self) -> None: policy=self._thread.policy, cpu_ms=self._thread.cpu_quota_ms, mem_bytes=self._thread.mem_quota_bytes, + wall_time_ms=self._thread.wall_time_ms, + open_files_max=self._thread.open_files_max, + network_ops_max=self._thread.network_ops_max, + output_bytes_max=self._thread.output_bytes_max, + child_work_max=self._thread.child_work_max, allowed_imports=sorted(self._thread.allowed_imports) if self._thread.allowed_imports is not None else None, @@ -116,6 +122,10 @@ def __del__(self): def stats(self): return self._thread.stats + @property + def termination_reason(self) -> str | None: + return self._thread.termination_reason + class Supervisor: """Main supervisor owning all sandboxes.""" @@ -142,6 +152,34 @@ def __init__( self._watchdog = ResourceWatchdog(self) self._watchdog.start() self._policy_token: str | None = None + self._tenant_usage: dict[str, int] = {} + self._quota_ledger = os.environ.get("PYISOLATE_QUOTA_LEDGER") + self._replay_quota_ledger() + + def _replay_quota_ledger(self) -> None: + if not self._quota_ledger: + return + path = Path(self._quota_ledger) + if not path.exists(): + return + for line in path.read_text(encoding="utf-8").splitlines(): + tenant, _, delta_str = line.partition(",") + if not tenant: + continue + try: + delta = int(delta_str) + except ValueError: + continue + self._tenant_usage[tenant] = self._tenant_usage.get(tenant, 0) + delta + + def _record_tenant_usage(self, tenant: str, delta: int = 1) -> None: + self._tenant_usage[tenant] = self._tenant_usage.get(tenant, 0) + delta + if not self._quota_ledger: + return + path = Path(self._quota_ledger) + path.parent.mkdir(parents=True, 
exist_ok=True) + with path.open("a", encoding="utf-8") as handle: + handle.write(f"{tenant},{delta}\n") def _recover_state(self) -> None: """Recover durable supervisor state and clean stale resources.""" @@ -171,9 +209,16 @@ def spawn( policy=None, cpu_ms: Optional[int] = None, mem_bytes: Optional[int] = None, + wall_time_ms: Optional[int] = None, + open_files_max: Optional[int] = None, + network_ops_max: Optional[int] = None, + output_bytes_max: Optional[int] = None, + child_work_max: Optional[int] = None, allowed_imports: Optional[list[str]] = None, numa_node: Optional[int] = None, capabilities: Optional[dict[str, object]] = None, + tenant: Optional[str] = None, + tenant_quota: Optional[int] = None, ) -> Sandbox: """Create and start a sandbox thread.""" global NAME_PATTERN @@ -196,6 +241,10 @@ def spawn( existing = self._sandboxes.get(name) if existing is not None and existing.is_alive(): raise RuntimeError(f"sandbox '{name}' already exists") + if tenant and tenant_quota is not None: + if self._tenant_usage.get(tenant, 0) >= tenant_quota: + raise TenantQuotaExceeded() + self._record_tenant_usage(tenant, 1) cg_path = cgroup.create(name, cpu_ms, mem_bytes) temp_dir = recovery.allocate_temp_dir(name) @@ -206,6 +255,11 @@ def spawn( policy=policy, cpu_ms=cpu_ms, mem_bytes=mem_bytes, + wall_time_ms=wall_time_ms, + open_files_max=open_files_max, + network_ops_max=network_ops_max, + output_bytes_max=output_bytes_max, + child_work_max=child_work_max, allowed_imports=allowed_imports, numa_node=numa_node, cgroup_path=cg_path, @@ -219,6 +273,11 @@ def spawn( policy=policy, cpu_ms=cpu_ms, mem_bytes=mem_bytes, + wall_time_ms=wall_time_ms, + open_files_max=open_files_max, + network_ops_max=network_ops_max, + output_bytes_max=output_bytes_max, + child_work_max=child_work_max, allowed_imports=allowed_imports, on_violation=self._alerts.notify, tracer=self._tracer, @@ -348,6 +407,11 @@ def recycle(self, name: str) -> Sandbox: policy=snap["policy"], cpu_ms=snap["cpu_ms"], 
def test_sandbox_termination_reason_passthrough():
    """The sandbox handle exposes the reason recorded for an output-quota stop."""
    supervisor = iso.Supervisor()
    try:
        sandbox = supervisor.spawn("term", output_bytes_max=1)
        sandbox.exec("post('xx')")
        with pytest.raises(iso.OutputExceeded):
            sandbox.recv(timeout=0.5)
        reason = sandbox.termination_reason
        assert reason == "output_exceeded"
    finally:
        supervisor.shutdown()


def test_tenant_quota_is_durable(tmp_path, monkeypatch):
    """Tenant usage recorded by one supervisor is enforced by the next one."""
    ledger_file = tmp_path / "quota.log"
    monkeypatch.setenv("PYISOLATE_QUOTA_LEDGER", str(ledger_file))

    # First supervisor uses up the tenant's single slot.
    first_sup = iso.Supervisor()
    try:
        first_sup.spawn("t1", tenant="acme", tenant_quota=1).close()
    finally:
        first_sup.shutdown()

    # Second supervisor must refuse the same tenant.
    second_sup = iso.Supervisor()
    try:
        with pytest.raises(iso.TenantQuotaExceeded):
            second_sup.spawn("t2", tenant="acme", tenant_quota=1)
    finally:
        second_sup.shutdown()


def test_wall_time_quota_hard_stop():
    """A busy loop is stopped with WallTimeExceeded."""
    sandbox = thread.SandboxThread("wall", wall_time_ms=5)
    sandbox.start()
    try:
        sandbox.exec("while True:\n pass")
        with pytest.raises(errors.WallTimeExceeded):
            sandbox.recv(timeout=1)
        reason = sandbox.termination_reason
        assert reason == "wall_time_exceeded"
    finally:
        sandbox.stop()


def test_open_files_quota_hard_stop(tmp_path):
    """Opening a second file with ``open_files_max=1`` raises OpenFilesExceeded."""
    first = tmp_path / "a.txt"
    second = tmp_path / "b.txt"
    for target, text in ((first, "a"), (second, "b")):
        target.write_text(text, encoding="utf-8")

    policy_cls = type("Policy", (), {"fs": {str(tmp_path)}})
    sandbox = thread.SandboxThread("files", policy=policy_cls(), open_files_max=1)
    sandbox.start()
    try:
        guest_source = "\n".join(
            [
                f"f1 = open({str(first)!r}, 'r')",
                f"f2 = open({str(second)!r}, 'r')",
                "post('ok')",
            ]
        )
        sandbox.exec(guest_source)
        with pytest.raises(errors.OpenFilesExceeded):
            sandbox.recv(timeout=1)
        reason = sandbox.termination_reason
        assert reason == "open_files_exceeded"
    finally:
        sandbox.stop()


def test_output_quota_hard_stop():
    """Posting five bytes against a four-byte quota raises OutputExceeded."""
    sandbox = thread.SandboxThread("output", output_bytes_max=4)
    sandbox.start()
    try:
        sandbox.exec("post('12345')")
        with pytest.raises(errors.OutputExceeded):
            sandbox.recv(timeout=1)
        reason = sandbox.termination_reason
        assert reason == "output_exceeded"
    finally:
        sandbox.stop()


def test_network_ops_quota_hard_stop():
    """A policy-allowed connect with ``network_ops_max=0`` raises NetworkExceeded."""
    policy_cls = type("Policy", (), {"tcp": {"127.0.0.1:80"}})
    sandbox = thread.SandboxThread("net", policy=policy_cls(), network_ops_max=0)
    sandbox.start()
    try:
        sandbox.exec("import socket\nsocket.socket().connect(('127.0.0.1', 80))")
        with pytest.raises(errors.NetworkExceeded):
            sandbox.recv(timeout=1)
        reason = sandbox.termination_reason
        assert reason == "network_exceeded"
    finally:
        sandbox.stop()


def test_child_work_quota_hard_stop():
    """Starting a thread with ``child_work_max=0`` raises ChildWorkExceeded."""
    sandbox = thread.SandboxThread("child", child_work_max=0)
    sandbox.start()
    try:
        sandbox.exec("import threading\nthreading.Thread(target=lambda: None).start()")
        with pytest.raises(errors.ChildWorkExceeded):
            sandbox.recv(timeout=1)
        reason = sandbox.termination_reason
        assert reason == "child_work_exceeded"
    finally:
        sandbox.stop()