diff --git a/pyisolate/supervisor.py b/pyisolate/supervisor.py index 2c6a395..337dacc 100644 --- a/pyisolate/supervisor.py +++ b/pyisolate/supervisor.py @@ -241,61 +241,78 @@ def spawn( existing = self._sandboxes.get(name) if existing is not None and existing.is_alive(): raise RuntimeError(f"sandbox '{name}' already exists") + usage_reserved = False if tenant and tenant_quota is not None: if self._tenant_usage.get(tenant, 0) >= tenant_quota: raise TenantQuotaExceeded() self._record_tenant_usage(tenant, 1) + usage_reserved = True - cg_path = cgroup.create(name, cpu_ms, mem_bytes) - temp_dir = recovery.allocate_temp_dir(name) - if self._warm_pool: - thread = self._warm_pool.pop() - thread.reset( + cg_path = None + temp_dir = None + thread = None + try: + cg_path = cgroup.create(name, cpu_ms, mem_bytes) + temp_dir = recovery.allocate_temp_dir(name) + if self._warm_pool: + thread = self._warm_pool.pop() + thread.reset( + name, + policy=policy, + cpu_ms=cpu_ms, + mem_bytes=mem_bytes, + wall_time_ms=wall_time_ms, + open_files_max=open_files_max, + network_ops_max=network_ops_max, + output_bytes_max=output_bytes_max, + child_work_max=child_work_max, + allowed_imports=allowed_imports, + numa_node=numa_node, + cgroup_path=cg_path, + capabilities=capabilities, + ) + thread._on_violation = self._alerts.notify + thread._tracer = self._tracer + else: + thread = SandboxThread( + name=name, + policy=policy, + cpu_ms=cpu_ms, + mem_bytes=mem_bytes, + wall_time_ms=wall_time_ms, + open_files_max=open_files_max, + network_ops_max=network_ops_max, + output_bytes_max=output_bytes_max, + child_work_max=child_work_max, + allowed_imports=allowed_imports, + on_violation=self._alerts.notify, + tracer=self._tracer, + numa_node=numa_node, + cgroup_path=cg_path, + capabilities=capabilities, + ) + thread.start() + thread._temp_dir = temp_dir + self._sandboxes[name] = thread + recovery.update_sandbox( name, - policy=policy, - cpu_ms=cpu_ms, - mem_bytes=mem_bytes, - wall_time_ms=wall_time_ms, - open_files_max=open_files_max, - network_ops_max=network_ops_max, - output_bytes_max=output_bytes_max, - child_work_max=child_work_max, - allowed_imports=allowed_imports, - numa_node=numa_node, - cgroup_path=cg_path, - capabilities=capabilities, - ) - thread._on_violation = self._alerts.notify - thread._tracer = self._tracer - else: - thread = SandboxThread( - name=name, - policy=policy, - cpu_ms=cpu_ms, - mem_bytes=mem_bytes, - wall_time_ms=wall_time_ms, - open_files_max=open_files_max, - network_ops_max=network_ops_max, - output_bytes_max=output_bytes_max, - child_work_max=child_work_max, - allowed_imports=allowed_imports, - on_violation=self._alerts.notify, - tracer=self._tracer, - numa_node=numa_node, - cgroup_path=cg_path, - capabilities=capabilities, + { + "name": name, + "cgroup_path": str(cg_path) if cg_path is not None else None, + "temp_dir": str(temp_dir), + }, ) - thread.start() - thread._temp_dir = temp_dir - self._sandboxes[name] = thread - recovery.update_sandbox( - name, - { - "name": name, - "cgroup_path": str(cg_path) if cg_path is not None else None, - "temp_dir": str(temp_dir), - }, - ) + except Exception: + self._sandboxes.pop(name, None) + if thread is not None and thread.is_alive(): + thread.stop() + cgroup.delete(cg_path) + if temp_dir is not None: + recovery.cleanup_temp_dir(temp_dir) + recovery.drop_sandbox(name) + if usage_reserved and tenant: + self._record_tenant_usage(tenant, -1) + raise # Remove references to any terminated sandboxes self._cleanup() # Reset any temporary overrides of the name validation pattern to avoid diff --git a/tests/test_supervisor.py b/tests/test_supervisor.py index f630c6b..a293144 100644 --- a/tests/test_supervisor.py +++ b/tests/test_supervisor.py @@ -207,3 +207,56 @@ def test_tenant_quota_is_durable(tmp_path, monkeypatch): sup2.spawn("t2", tenant="acme", tenant_quota=1) finally: sup2.shutdown() + + +def test_spawn_start_failure_rolls_back_tenant_usage_and_ledger(tmp_path, monkeypatch): + ledger = tmp_path / "quota.log" + monkeypatch.setenv("PYISOLATE_QUOTA_LEDGER", str(ledger)) + + def fail_start(self): + raise RuntimeError("start failed") + + monkeypatch.setattr("pyisolate.runtime.thread.SandboxThread.start", fail_start) + + sup = iso.Supervisor() + try: + with pytest.raises(RuntimeError, match="start failed"): + sup.spawn("tenant-start-fail", tenant="acme", tenant_quota=1) + assert sup._tenant_usage.get("acme", 0) == 0 + finally: + sup.shutdown() + + assert ledger.read_text(encoding="utf-8").splitlines() == ["acme,1", "acme,-1"] + + sup_replay = iso.Supervisor() + try: + assert sup_replay._tenant_usage.get("acme", 0) == 0 + finally: + sup_replay.shutdown() + + +def test_spawn_registry_failure_rolls_back_tenant_usage_and_ledger(tmp_path, monkeypatch): + ledger = tmp_path / "quota.log" + monkeypatch.setenv("PYISOLATE_QUOTA_LEDGER", str(ledger)) + + def fail_update(*_args, **_kwargs): + raise RuntimeError("registry update failed") + + monkeypatch.setattr("pyisolate.recovery.update_sandbox", fail_update) + + sup = iso.Supervisor() + try: + with pytest.raises(RuntimeError, match="registry update failed"): + sup.spawn("tenant-registry-fail", tenant="acme", tenant_quota=1) + assert sup._tenant_usage.get("acme", 0) == 0 + assert "tenant-registry-fail" not in sup._sandboxes + finally: + sup.shutdown() + + assert ledger.read_text(encoding="utf-8").splitlines() == ["acme,1", "acme,-1"] + + sup_replay = iso.Supervisor() + try: + assert sup_replay._tenant_usage.get("acme", 0) == 0 + finally: + sup_replay.shutdown()