Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 65 additions & 48 deletions pyisolate/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,61 +241,78 @@ def spawn(
existing = self._sandboxes.get(name)
if existing is not None and existing.is_alive():
raise RuntimeError(f"sandbox '{name}' already exists")
usage_reserved = False
if tenant and tenant_quota is not None:
if self._tenant_usage.get(tenant, 0) >= tenant_quota:
raise TenantQuotaExceeded()
self._record_tenant_usage(tenant, 1)
usage_reserved = True

cg_path = cgroup.create(name, cpu_ms, mem_bytes)
temp_dir = recovery.allocate_temp_dir(name)
if self._warm_pool:
thread = self._warm_pool.pop()
thread.reset(
cg_path = None
temp_dir = None
thread = None
try:
cg_path = cgroup.create(name, cpu_ms, mem_bytes)
temp_dir = recovery.allocate_temp_dir(name)
if self._warm_pool:
thread = self._warm_pool.pop()
thread.reset(
name,
policy=policy,
cpu_ms=cpu_ms,
mem_bytes=mem_bytes,
wall_time_ms=wall_time_ms,
open_files_max=open_files_max,
network_ops_max=network_ops_max,
output_bytes_max=output_bytes_max,
child_work_max=child_work_max,
allowed_imports=allowed_imports,
numa_node=numa_node,
cgroup_path=cg_path,
capabilities=capabilities,
)
thread._on_violation = self._alerts.notify
thread._tracer = self._tracer
else:
thread = SandboxThread(
name=name,
policy=policy,
cpu_ms=cpu_ms,
mem_bytes=mem_bytes,
wall_time_ms=wall_time_ms,
open_files_max=open_files_max,
network_ops_max=network_ops_max,
output_bytes_max=output_bytes_max,
child_work_max=child_work_max,
allowed_imports=allowed_imports,
on_violation=self._alerts.notify,
tracer=self._tracer,
numa_node=numa_node,
cgroup_path=cg_path,
capabilities=capabilities,
)
thread.start()
thread._temp_dir = temp_dir
self._sandboxes[name] = thread
recovery.update_sandbox(
name,
policy=policy,
cpu_ms=cpu_ms,
mem_bytes=mem_bytes,
wall_time_ms=wall_time_ms,
open_files_max=open_files_max,
network_ops_max=network_ops_max,
output_bytes_max=output_bytes_max,
child_work_max=child_work_max,
allowed_imports=allowed_imports,
numa_node=numa_node,
cgroup_path=cg_path,
capabilities=capabilities,
)
thread._on_violation = self._alerts.notify
thread._tracer = self._tracer
else:
thread = SandboxThread(
name=name,
policy=policy,
cpu_ms=cpu_ms,
mem_bytes=mem_bytes,
wall_time_ms=wall_time_ms,
open_files_max=open_files_max,
network_ops_max=network_ops_max,
output_bytes_max=output_bytes_max,
child_work_max=child_work_max,
allowed_imports=allowed_imports,
on_violation=self._alerts.notify,
tracer=self._tracer,
numa_node=numa_node,
cgroup_path=cg_path,
capabilities=capabilities,
{
"name": name,
"cgroup_path": str(cg_path) if cg_path is not None else None,
"temp_dir": str(temp_dir),
},
)
thread.start()
thread._temp_dir = temp_dir
self._sandboxes[name] = thread
recovery.update_sandbox(
name,
{
"name": name,
"cgroup_path": str(cg_path) if cg_path is not None else None,
"temp_dir": str(temp_dir),
},
)
except Exception:
self._sandboxes.pop(name, None)
if thread is not None and thread.is_alive():
thread.stop()
cgroup.delete(cg_path)
if temp_dir is not None:
recovery.cleanup_temp_dir(temp_dir)
recovery.drop_sandbox(name)
if usage_reserved and tenant:
self._record_tenant_usage(tenant, -1)
raise
# Remove references to any terminated sandboxes
self._cleanup()
# Reset any temporary overrides of the name validation pattern to avoid
Expand Down
53 changes: 53 additions & 0 deletions tests/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,56 @@ def test_tenant_quota_is_durable(tmp_path, monkeypatch):
sup2.spawn("t2", tenant="acme", tenant_quota=1)
finally:
sup2.shutdown()


def test_spawn_start_failure_rolls_back_tenant_usage_and_ledger(tmp_path, monkeypatch):
    """A failure inside SandboxThread.start() must undo the tenant-quota
    reservation both in memory and in the durable quota ledger."""
    ledger_path = tmp_path / "quota.log"
    monkeypatch.setenv("PYISOLATE_QUOTA_LEDGER", str(ledger_path))

    def boom(self):
        raise RuntimeError("start failed")

    monkeypatch.setattr("pyisolate.runtime.thread.SandboxThread.start", boom)

    supervisor = iso.Supervisor()
    try:
        with pytest.raises(RuntimeError, match="start failed"):
            supervisor.spawn("tenant-start-fail", tenant="acme", tenant_quota=1)
        # In-memory usage must net out to zero after the rollback.
        assert supervisor._tenant_usage.get("acme", 0) == 0
    finally:
        supervisor.shutdown()

    # The ledger records the reservation followed by its compensating entry.
    assert ledger_path.read_text(encoding="utf-8").splitlines() == ["acme,1", "acme,-1"]

    # Replaying the ledger into a fresh supervisor must also yield zero usage.
    replayed = iso.Supervisor()
    try:
        assert replayed._tenant_usage.get("acme", 0) == 0
    finally:
        replayed.shutdown()


def test_spawn_registry_failure_rolls_back_tenant_usage_and_ledger(tmp_path, monkeypatch):
    """A failure in recovery.update_sandbox() must roll back the quota
    reservation (memory + ledger) and deregister the partially-built sandbox."""
    ledger_path = tmp_path / "quota.log"
    monkeypatch.setenv("PYISOLATE_QUOTA_LEDGER", str(ledger_path))

    def boom(*_args, **_kwargs):
        raise RuntimeError("registry update failed")

    monkeypatch.setattr("pyisolate.recovery.update_sandbox", boom)

    supervisor = iso.Supervisor()
    try:
        with pytest.raises(RuntimeError, match="registry update failed"):
            supervisor.spawn("tenant-registry-fail", tenant="acme", tenant_quota=1)
        # Both the quota counter and the sandbox table must be cleaned up.
        assert supervisor._tenant_usage.get("acme", 0) == 0
        assert "tenant-registry-fail" not in supervisor._sandboxes
    finally:
        supervisor.shutdown()

    # Reservation and compensating entry are both persisted, in order.
    assert ledger_path.read_text(encoding="utf-8").splitlines() == ["acme,1", "acme,-1"]

    # A fresh supervisor replaying the ledger must see zero net usage.
    replayed = iso.Supervisor()
    try:
        assert replayed._tenant_usage.get("acme", 0) == 0
    finally:
        replayed.shutdown()
Loading