diff --git a/cli/stack/src/flowmesh_cli_stack/assets/.env.example b/cli/stack/src/flowmesh_cli_stack/assets/.env.example
index 7ba2bf2c..c34520bd 100644
--- a/cli/stack/src/flowmesh_cli_stack/assets/.env.example
+++ b/cli/stack/src/flowmesh_cli_stack/assets/.env.example
@@ -93,6 +93,13 @@ SSH_DEFAULT_IDLE_SEC=
 SSH_MAX_TTL_SEC=
 SSH_POLL_INTERVAL_SEC=
 SSH_STOP_TIMEOUT_SEC=
+SSH_MAX_CPU=
+SSH_MAX_MEMORY=
+SSH_MAX_PIDS=
+# Whether to apply requested GPU limits to SSH tasks.
+# If false, SSH tasks are allocated all available GPUs
+# regardless of their resource requests.
+ENABLE_SSH_GPU_LIMIT=false
 
 # ==== General Settings ====
 TZ=Asia/Singapore
diff --git a/cli/stack/src/flowmesh_cli_stack/env_schema.py b/cli/stack/src/flowmesh_cli_stack/env_schema.py
index 9c7edb3c..30644028 100644
--- a/cli/stack/src/flowmesh_cli_stack/env_schema.py
+++ b/cli/stack/src/flowmesh_cli_stack/env_schema.py
@@ -301,6 +301,24 @@
             EnvVar("SSH_MAX_TTL_SEC", var_type=EnvVarType.FLOAT, min_value=0),
             EnvVar("SSH_POLL_INTERVAL_SEC", var_type=EnvVarType.FLOAT, min_value=0),
             EnvVar("SSH_STOP_TIMEOUT_SEC", var_type=EnvVarType.FLOAT, min_value=0),
+            EnvVar(
+                "SSH_MAX_CPU",
+                var_type=EnvVarType.FLOAT,
+                min_value=0,
+                min_inclusive=False,
+            ),
+            EnvVar("SSH_MAX_MEMORY"),
+            EnvVar("SSH_MAX_PIDS", var_type=EnvVarType.INT, min_value=1),
+            EnvVar(
+                "ENABLE_SSH_GPU_LIMIT",
+                "false",
+                var_type=EnvVarType.BOOL,
+                description=[
+                    "Whether to apply requested GPU limits to SSH tasks.",
+                    "If false, SSH tasks are allocated all available GPUs",
+                    "regardless of their resource requests.",
+                ],
+            ),
         ],
     ),
     EnvSection(
diff --git a/docs/ENV.md b/docs/ENV.md
index a6d613fc..9bde44e0 100644
--- a/docs/ENV.md
+++ b/docs/ENV.md
@@ -79,3 +79,41 @@ Spark), set `DOCKER_GPU_RUNTIME=` in the stack env.
 | `SUPERVISOR_GRPC_KEEPALIVE_PERMIT_WITHOUT_CALLS` | `true` | gRPC keepalive |
 | `SUPERVISOR_GRPC_EXTERNAL_PORT` | – | External port (when port-forwarded) |
 | `SERVER_GRPC_TLS_*` | – | TLS certificate files |
+
+## SSH session resource caps
+
+When `enable_ssh` is true on a Docker worker, these configured
+ceilings bound every SSH session container spawned by that worker.
+Unset values mean unbounded (host-wide access).
+
+| Variable | Default | Description |
+|----------|---------|-------------|
+| `SSH_MAX_CPU` | – | Max CPU cores per SSH container (float, e.g. `4` or `2.5`). Sets Docker `nano_cpus`. |
+| `SSH_MAX_MEMORY` | – | Max memory per SSH container (e.g. `8Gi`, `512Mi`, or a byte count). Sets Docker `mem_limit`. |
+| `SSH_MAX_PIDS` | – | Max PIDs per SSH container. Sets Docker `pids_limit`. Admin-only — not user-overridable. |
+| `ENABLE_SSH_GPU_LIMIT` | `false` | When `true`, mount only the GPU subset matching the spec (`count` / `type` / `memory`); otherwise mount all worker GPUs. |
+
+The effective CPU/memory limit is `min(spec.resources.hardware, worker
+cap)`. A task that requests more than the worker cap is dispatched to
+another worker if one has a larger cap; otherwise the dispatcher
+follows its standard requeue/retry behavior. A worker that starts with
+no cap configured logs a startup warning.
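+A sketch of the clamping rule (illustrative only; `effective_limit` is a
+hypothetical name mirroring the `_min_or_none` helper this patch adds to
+`src/worker/executors/ssh_executor.py`):
+
+```python
+def effective_limit(spec_value: float | None, cap: float | None) -> float | None:
+    """None means unbounded; otherwise the stricter bound wins."""
+    if spec_value is None:
+        return cap
+    if cap is None:
+        return spec_value
+    return min(spec_value, cap)
+
+assert effective_limit(8.0, 4.0) == 4.0     # request above cap: clamped
+assert effective_limit(2.0, 4.0) == 2.0     # request below cap: request wins
+assert effective_limit(None, 4.0) == 4.0    # no request: cap still applies
+assert effective_limit(None, None) is None  # both unset: unbounded
+```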
diff --git a/sdk/src/flowmesh/models/__init__.py b/sdk/src/flowmesh/models/__init__.py
index 6c4f47e5..e2043805 100644
--- a/sdk/src/flowmesh/models/__init__.py
+++ b/sdk/src/flowmesh/models/__init__.py
@@ -38,6 +38,7 @@
     HostInfo,
     MemoryInfo,
     NetworkInfo,
+    SSHLimits,
    StorageInfo,
     Worker,
     WorkerHardware,
@@ -78,6 +79,7 @@
     "NodeWorkerInfo",
     "ProfileSummary",
     "ResultEnvelope",
+    "SSHLimits",
     "StorageInfo",
     "TaskInfo",
     "TaskStatus",
diff --git a/sdk/src/flowmesh/models/workers.py b/sdk/src/flowmesh/models/workers.py
index 746e997e..d10ed84c 100644
--- a/sdk/src/flowmesh/models/workers.py
+++ b/sdk/src/flowmesh/models/workers.py
@@ -73,6 +73,12 @@ class WorkerHardware(BaseModel):
     extra: dict[str, Any] | None = None
 
 
+class SSHLimits(BaseModel):
+    max_cpu_cores: float | None = None
+    max_memory_bytes: int | None = None
+    max_pids: int | None = None
+
+
 class Worker(BaseModel):
     id: str
     alias: str | None = None
@@ -85,6 +91,7 @@ class Worker(BaseModel):
     pid: int | None = None
     env: dict[str, Any] = Field(default_factory=dict)
     hardware: WorkerHardware | None = None
+    ssh_limits: SSHLimits | None = None
     tags: list[str] = Field(default_factory=list)
     last_seen: str | None = None
     cached_models: list[str] = Field(default_factory=list)
diff --git a/sdk/stack/src/flowmesh_stack/env_schema.py b/sdk/stack/src/flowmesh_stack/env_schema.py
index 25fd970e..825945b2 100644
--- a/sdk/stack/src/flowmesh_stack/env_schema.py
+++ b/sdk/stack/src/flowmesh_stack/env_schema.py
@@ -1,6 +1,7 @@
 """Environment schema definitions and pure validation helpers."""
 
 import enum
+import operator
 from collections.abc import Callable, Iterable, Mapping
 from dataclasses import dataclass, field
 from logging import _nameToLevel as LOG_LEVELS
@@ -34,7 +35,9 @@ class EnvVar:
     use_default: bool = False
     choices: Iterable[str] | None = None
     min_value: float | None = None
+    min_inclusive: bool = True
     max_value: float | None = None
+    max_inclusive: bool = True
     min_length: int | None = None
     ensure_path: Literal["error", "warn", "create"] | None = None
     url_schemes: set[str] | None = None
@@ -133,19 +136,13 @@ def validate_env_values(
                 if int_value is None:
                     errors.append(f"{var.key} must be an integer")
                     continue
-                if var.min_value is not None and int_value < var.min_value:
-                    errors.append(f"{var.key} must be >= {int(var.min_value)}")
-                if var.max_value is not None and int_value > var.max_value:
-                    errors.append(f"{var.key} must be <= {int(var.max_value)}")
+                _check_value_range(var, int_value, errors)
             case EnvVarType.FLOAT:
                 float_value = parse_float(raw)
                 if float_value is None:
                     errors.append(f"{var.key} must be a number")
                     continue
-                if var.min_value is not None and float_value < var.min_value:
-                    errors.append(f"{var.key} must be >= {var.min_value}")
-                if var.max_value is not None and float_value > var.max_value:
-                    errors.append(f"{var.key} must be <= {var.max_value}")
+                _check_value_range(var, float_value, errors)
             case EnvVarType.BOOL:
                 if parse_bool(raw) is None:
                     errors.append(
@@ -213,6 +210,36 @@ def require_all_or_none(
         errors.append(f"Either all or none of {', '.join(keys)} must be set")
 
 
+_COMPARE_OPS = {
+    "<": operator.lt,
+    "<=": operator.le,
+    ">": operator.gt,
+    ">=": operator.ge,
+}
+
+
+def _check_value_range(var: EnvVar, value: float, errors: list[str]) -> None:
+    comp: Literal["<", "<=", ">", ">="]
+    if var.min_value is not None:
+        comp = ">=" if var.min_inclusive else ">"
+        _check_bound(var, value, var.min_value, comp, errors)
+    if var.max_value is not None:
+        comp = "<=" if var.max_inclusive else "<"
+        _check_bound(var, value, var.max_value, comp, errors)
+
+
+def _check_bound(
+    var: EnvVar,
+    value: float,
+    bound: float,
+    comparison: Literal["<", "<=", ">", ">="],
+    errors: list[str],
+) -> None:
+    op = _COMPARE_OPS[comparison]
+    if not op(value, bound):
+        errors.append(f"{var.key} must be {comparison} {bound}")
+
+
 def _ensure_path(raw: str, var: EnvVar, errors: list[str], warnings: list[str]) -> None:
     if not raw:
         errors.append(f"{var.key} must be a non-empty path")
diff --git a/src/server/env.py b/src/server/env.py
index d1271f2d..7afba7e6 100644
--- a/src/server/env.py
+++ b/src/server/env.py
@@ -75,6 +75,10 @@
 SSH_MAX_TTL_SEC: float | None = parse_float_env("SSH_MAX_TTL_SEC")
 SSH_POLL_INTERVAL_SEC: float | None = parse_float_env("SSH_POLL_INTERVAL_SEC")
 SSH_STOP_TIMEOUT_SEC: float | None = parse_float_env("SSH_STOP_TIMEOUT_SEC")
+SSH_MAX_CPU: float | None = parse_float_env("SSH_MAX_CPU")
+SSH_MAX_MEMORY: str | None = os.getenv("SSH_MAX_MEMORY", "").strip() or None
+SSH_MAX_PIDS: int | None = parse_int_env("SSH_MAX_PIDS")
+ENABLE_SSH_GPU_LIMIT: bool = parse_bool_env("ENABLE_SSH_GPU_LIMIT", False)
 
 LOG_FILE: str = os.getenv("LOG_FILE", "server.log")
 LOG_MAX_BYTES: int = int(os.getenv("LOG_MAX_BYTES", 5_242_880))
diff --git a/src/server/registries/worker.py b/src/server/registries/worker.py
index b7b35430..94023ef8 100644
--- a/src/server/registries/worker.py
+++ b/src/server/registries/worker.py
@@ -1,5 +1,4 @@
 import json
-import re
 from collections.abc import Iterable, Sequence
 from typing import Any
 
@@ -10,14 +9,22 @@
     StopMessage,
     TaskMessage,
 )
+from shared.schemas.worker import SSHLimits
 from shared.tasks import TaskEnvelope
 from shared.tasks.components.resources import GPURequirements
+from shared.tasks.specs import SSHSpecStrict, SSHSpecTemplate
 from shared.tasks.worker_message import (
     WorkerHardware,
     WorkerStatus,
     WorkerTaskMessage,
 )
 from shared.utils import new_worker_id, now_iso, parse_mem_to_bytes
+from shared.utils.hardware import (
+    normalize_gpu_type,
+    parse_gpu_memory_bytes,
+    select_matching_gpu_indices,
+    unified_gpu_memory_satisfies,
+)
 
 from ..clients.redis import (
     WORKER_EVENT_CHANNEL,
@@ -48,6 +55,9 @@ class Worker(BaseModel):
     hardware: WorkerHardware | None = Field(
         default=None, description="Hardware metadata."
     )
+    ssh_limits: SSHLimits | None = Field(
+        default=None, description="Configured ceiling on SSH session resources."
+    )
     tags: list[str] = Field(default_factory=list, description="Worker tags.")
     last_seen: str | None = Field(default=None, description="Last heartbeat timestamp.")
     cached_models: list[str] = Field(
@@ -484,14 +494,36 @@ def hw_satisfies(worker: Worker, task: TaskEnvelope) -> bool:
     mem_needed = requirements.memory
     gpu_req = requirements.gpu
 
+    # Consider SSH hardware limits for SSH tasks.
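+    # Caps narrow only the envelope advertised for SSH tasks; every other task
+    # type is still matched against the worker's full physical hardware.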
+    ssh_caps = (
+        worker.ssh_limits
+        if isinstance(task.spec, (SSHSpecStrict, SSHSpecTemplate))
+        and worker.ssh_limits is not None
+        else None
+    )
+
     if cpu_needed is not None:
-        cpu_cores = None if hw is None else hw.cpu.logical_cores
+        cpu_cores: float | None = None if hw is None else hw.cpu.logical_cores
+        if (
+            ssh_caps is not None
+            and ssh_caps.max_cpu_cores is not None
+            and cpu_cores is not None
+        ):
+            cpu_cores = min(cpu_cores, ssh_caps.max_cpu_cores)
         if cpu_cores is None or cpu_cores < cpu_needed:
             return False
 
     if mem_needed:
         required_bytes = parse_mem_to_bytes(str(mem_needed)) or 0
         available = 0 if hw is None else (hw.memory.total_bytes or 0)
+        if (
+            ssh_caps is not None
+            and ssh_caps.max_memory_bytes is not None
+            and available > 0
+        ):
+            available = min(available, ssh_caps.max_memory_bytes)
         if available < required_bytes:
             return False
 
@@ -511,34 +541,27 @@ def _gpu_meets_requirements(hw: WorkerHardware, gpu_req: GPURequirements) -> bool:
         required_count = int(required_count)
     except Exception:
         required_count = None
-    required_type = str(gpu_req.type or "").strip().lower()
-    if required_type in {"", "any", "auto", "*"}:
-        required_type = ""
-    required_memory = gpu_req.memory
-    required_memory_bytes = (
-        parse_mem_to_bytes(str(required_memory)) if required_memory else None
-    )
+    required_type = normalize_gpu_type(gpu_req.type)
+    required_memory_bytes = parse_gpu_memory_bytes(gpu_req.memory)
+    needed = required_count or 1
 
     entries = hw.gpu.devices
     if entries:
         if required_count is not None and len(entries) < required_count:
             return False
-        if required_type:
-            pattern = re.compile(re.escape(required_type), re.IGNORECASE)
-            if not any(pattern.search(entry.name) for entry in entries):
-                return False
-        if required_memory_bytes:
-            needed = required_count or 1
-            eligible = sum(
-                1
-                for entry in entries
-                if (entry.memory_total_bytes or 0) >= required_memory_bytes
-            )
-            if eligible < needed and not _unified_gpu_memory_satisfies(
-                hw, required_memory_bytes, needed
-            ):
-                return False
-        return True
+        if len(select_matching_gpu_indices(entries, gpu_req)) >= needed:
+            return True
+        # Unified-memory fallback: when memory is the binding constraint and
+        # the worker exposes a unified GPU/system pool large enough to cover
+        # the request, still admit it.
+        if required_memory_bytes is None:
+            return False
+        type_only_req = GPURequirements(
+            count=gpu_req.count, type=gpu_req.type, memory=None
+        )
+        if len(select_matching_gpu_indices(entries, type_only_req)) < needed:
+            return False
+        return unified_gpu_memory_satisfies(hw, required_memory_bytes, needed)
 
     # Fallback when workers report aggregate GPU data instead of per-device entries.
     count = 0 if hw is None else len(hw.gpu.devices)
@@ -554,13 +577,12 @@ def _gpu_meets_requirements(hw: WorkerHardware, gpu_req: GPURequirements) -> bool:
 
     if required_memory_bytes:
         total_mem = 0 if first_gpu is None else (first_gpu.memory_total_bytes or 0)
-        if total_mem <= 0 and _unified_gpu_memory_satisfies(
-            hw, required_memory_bytes, required_count or 1
+        if total_mem <= 0 and unified_gpu_memory_satisfies(
+            hw, required_memory_bytes, needed
         ):
             return True
         if total_mem <= 0:
             return False
-        needed = required_count or 1
         per_gpu = total_mem / max(needed, 1)
         if per_gpu < required_memory_bytes:
             return False
@@ -577,18 +599,6 @@ def dedicated_gpu_memory_total_bytes(hw: WorkerHardware | None) -> int:
     return total
 
 
-def _unified_gpu_memory_satisfies(
-    hw: WorkerHardware, required_memory_bytes: int, required_count: int
-) -> bool:
-    if not hw.gpu.memory_is_unified:
-        return False
-    shared_total = hw.gpu.shared_memory_total_bytes or 0
-    if shared_total <= 0:
-        return False
-    per_gpu_share = shared_total / max(required_count, 1)
-    return per_gpu_share >= required_memory_bytes
-
-
 def _parse_worker_from_redis(
     worker_id: str, value: dict[str, Any] | None
 ) -> Worker | None:
@@ -621,6 +631,12 @@ def _ensure_str_list(items: Any) -> list[str]:
         if hardware_json is None
         else WorkerHardware.model_validate_json(hardware_json)
     )
+    ssh_limits_json = value.get("ssh_limits_json")
+    ssh_limits = (
+        None
+        if ssh_limits_json is None
+        else SSHLimits.model_validate_json(ssh_limits_json)
+    )
     tags = _loads(value.get("tags_json"), [])
     cached_models = _ensure_str_list(_loads(value.get("cache_models_json"), []))
     cached_datasets = _ensure_str_list(_loads(value.get("cache_datasets_json"), []))
@@ -650,6 +666,7 @@
         pid=pid,
         env=env,
         hardware=hardware,
+        ssh_limits=ssh_limits,
         tags=tags,
         last_seen=value.get("last_seen"),
         cached_models=cached_models,
diff --git a/src/server/supervisor/adapters/docker.py b/src/server/supervisor/adapters/docker.py
index 2811814c..18327930 100644
--- a/src/server/supervisor/adapters/docker.py
+++ b/src/server/supervisor/adapters/docker.py
@@ -14,6 +14,8 @@
 from docker.types import DeviceRequest
 from pydantic import BaseModel, Field
 
+from shared.schemas.worker import SSHLimits
+from shared.utils import parse_mem_to_bytes
 from shared.utils.docker import sanitize_container_name
 
 from ... import env
@@ -28,7 +30,7 @@
     WorkerFactory,
     WorkerTokenType,
 )
-from .utils import get_worker_image_name
+from .utils import get_worker_image_name, to_env_str
 
 _STOP_TIMEOUT = 30  # seconds
 _PROVIDER_NAME = "docker"
@@ -113,6 +115,17 @@ class SSHConfig(BaseModel):
     """Container status poll interval in seconds"""
     stop_timeout_sec: float | None = env.SSH_STOP_TIMEOUT_SEC
     """Seconds to wait when stopping a session container"""
+    max_cpu: float | None = env.SSH_MAX_CPU
+    """Maximum CPU cores accessible to an SSH session container"""
+    max_memory: str | None = env.SSH_MAX_MEMORY
+    """Maximum memory accessible to an SSH session container (e.g. "8Gi")"""
+    max_pids: int | None = env.SSH_MAX_PIDS
+    """Maximum number of PIDs inside an SSH session container"""
+    enable_gpu_limit: bool = env.ENABLE_SSH_GPU_LIMIT
+    """Whether to apply requested GPU limits to SSH session containers.
+
+    If false, SSH session containers are allocated all available GPUs regardless of
+    their resource requests."""
 
     def to_env(self) -> dict[str, str]:
         """Return env vars to inject into the worker container."""
@@ -124,8 +137,30 @@ def to_env(self) -> dict[str, str]:
             "SSH_MAX_TTL_SEC": self.max_ttl_sec,
             "SSH_POLL_INTERVAL_SEC": self.poll_interval_sec,
             "SSH_STOP_TIMEOUT_SEC": self.stop_timeout_sec,
+            "SSH_MAX_CPU": self.max_cpu,
+            "SSH_MAX_MEMORY": self.max_memory,
+            "SSH_MAX_PIDS": self.max_pids,
+            "ENABLE_SSH_GPU_LIMIT": self.enable_gpu_limit,
         }
-        return {k: str(v) for k, v in mapping.items() if v is not None}
+        return {k: to_env_str(v) for k, v in mapping.items() if v is not None}
+
+    def to_limits(self) -> SSHLimits | None:
+        """Project the admin cap into a wire-ready ``SSHLimits``."""
+        memory_bytes: int | None = None
+        if self.max_memory is not None:
+            memory_bytes = parse_mem_to_bytes(self.max_memory)
+            if memory_bytes is None:
+                raise ValueError(
+                    f"SSH_MAX_MEMORY value {self.max_memory!r} is not a valid "
+                    "memory string (e.g. '8Gi', '512Mi', or a byte count)"
+                )
+        if self.max_cpu is None and memory_bytes is None and self.max_pids is None:
+            return None
+        return SSHLimits(
+            max_cpu_cores=self.max_cpu,
+            max_memory_bytes=memory_bytes,
+            max_pids=self.max_pids,
+        )
 
 
 class DockerWorkerConfig(WorkerConfig):
@@ -210,6 +245,7 @@ def get_info(self) -> DockerWorkerInfo:
             provider=_PROVIDER_NAME,
             status=self.status,
             hardware=hardware,
+            ssh_limits=self.config.ssh.to_limits() if self.config.enable_ssh else None,
         )
 
     async def start(self) -> bool:
@@ -488,7 +524,8 @@ def _base_environment(self) -> dict[str, str]:
             environment["WORKER_NETWORK_MODE"] = f"container:{self.container_name}"
             environment["WORKER_CONTAINER_NAME"] = self.container_name
             environment["SSH_NETWORK_NAME"] = _SSH_NETWORK_NAME
-        environment.update(self.config.ssh.to_env())
+        if self.config.enable_ssh:
+            environment.update(self.config.ssh.to_env())
         return environment
 
     def _apply_worker_type_settings(
diff --git a/src/server/supervisor/schemas.py b/src/server/supervisor/schemas.py
index 4e9a0916..f03dbe65 100644
--- a/src/server/supervisor/schemas.py
+++ b/src/server/supervisor/schemas.py
@@ -3,6 +3,8 @@
 
 from pydantic import BaseModel, Field
 
+from shared.schemas.worker import SSHLimits
+
 from .. import env
 from ..schemas.node import WorkerHardware
 
@@ -25,6 +27,13 @@ class WorkerInfo(BaseModel):
     hardware: Annotated[
         WorkerHardware | None, Field(default=None, description="Hardware metadata")
     ]
+    ssh_limits: Annotated[
+        SSHLimits | None,
+        Field(
+            default=None,
+            description="Configured ceiling on SSH session resources.",
+        ),
+    ] = None
 
 
 __all__ = ["WorkerHardware", "WorkerInfo", "WorkerStatus"]
diff --git a/src/shared/schemas/worker.py b/src/shared/schemas/worker.py
index 0ff4b09e..ebb967ee 100644
--- a/src/shared/schemas/worker.py
+++ b/src/shared/schemas/worker.py
@@ -1,5 +1,7 @@
 from enum import StrEnum
 
+from pydantic import BaseModel, Field
+
 
 class WorkerStatus(StrEnum):
     UNKNOWN = "UNKNOWN"
@@ -8,4 +10,24 @@ class WorkerStatus(StrEnum):
     BUSY = "BUSY"
 
 
-__all__ = ["WorkerStatus"]
+class SSHLimits(BaseModel):
+    """Per-worker ceiling for resources accessible by SSH session containers.
+
+    Populated from the worker's ``SSH_MAX_*`` configuration. Used by the dispatcher to
+    filter workers for SSH tasks and by the worker at runtime to clamp the spawned
+    container's cgroup limits.
+    """
+
+    max_cpu_cores: float | None = Field(
+        default=None, description="Maximum CPU cores accessible to an SSH session."
+    )
+    max_memory_bytes: int | None = Field(
+        default=None,
+        description="Maximum memory in bytes accessible to an SSH session.",
+    )
+    max_pids: int | None = Field(
+        default=None, description="Maximum number of PIDs inside an SSH session."
+    )
+
+
+__all__ = ["SSHLimits", "WorkerStatus"]
diff --git a/src/shared/utils/hardware.py b/src/shared/utils/hardware.py
new file mode 100644
index 00000000..1ac5a47d
--- /dev/null
+++ b/src/shared/utils/hardware.py
@@ -0,0 +1,102 @@
+"""Helpers for parsing GPU requirement specs and matching them against devices."""
+
+import re
+
+from shared.tasks.components.resources import GPURequirements
+from shared.tasks.worker_message import GpuInfo, WorkerHardware
+from shared.utils.parsing import parse_mem_to_bytes
+
+_GPU_TYPE_WILDCARDS = frozenset({"", "any", "auto", "*"})
+
+
+def normalize_gpu_type(value: str | None) -> str | None:
+    """Lowercase the type; ``None`` for wildcards (``''``/``any``/``auto``/``*``)."""
+    if value is None:
+        return None
+    normalized = value.strip().lower()
+    return None if normalized in _GPU_TYPE_WILDCARDS else normalized
+
+
+def gpu_type_pattern(value: str | None) -> re.Pattern[str] | None:
+    """Case-insensitive substring matcher, or ``None`` for wildcard."""
+    normalized = normalize_gpu_type(value)
+    if normalized is None:
+        return None
+    return re.compile(re.escape(normalized), re.IGNORECASE)
+
+
+def parse_gpu_memory_bytes(value: str | int | float | None) -> int | None:
+    """Parse ``GPURequirements.memory`` (str / int / float / None) to bytes.
+
+    Returns ``None`` for ``None`` and for unparsable strings.
+    """
+    if value is None:
+        return None
+    if isinstance(value, str):
+        return parse_mem_to_bytes(value)
+    return int(value)
+
+
+def gpu_device_matches(
+    device: GpuInfo,
+    *,
+    type_pattern: re.Pattern[str] | None = None,
+    min_memory_bytes: int | None = None,
+) -> bool:
+    """Per-device predicate; ``None`` arg means 'no constraint'."""
+    if type_pattern is not None and not type_pattern.search(device.name or ""):
+        return False
+    return (
+        min_memory_bytes is None or (device.memory_total_bytes or 0) >= min_memory_bytes
+    )
+
+
+def unified_gpu_memory_satisfies(
+    hw: WorkerHardware, required_memory_bytes: int, required_count: int
+) -> bool:
+    """Pessimistic per-slot share of a unified GPU/system memory pool.
+
+    Returns ``True`` only when ``hw.gpu.memory_is_unified`` and
+    ``shared_memory_total_bytes / required_count >= required_memory_bytes``.
+    """
+    if not hw.gpu.memory_is_unified:
+        return False
+    shared_total = hw.gpu.shared_memory_total_bytes or 0
+    if shared_total <= 0:
+        return False
+    per_gpu_share = shared_total / max(required_count, 1)
+    return per_gpu_share >= required_memory_bytes
+
+
+def select_matching_gpu_indices(
+    devices: list[GpuInfo],
+    gpu_req: GPURequirements,
+    *,
+    limit: int | None = None,
+) -> list[int]:
+    """Indices of devices that individually pass ``gpu_req``'s type + memory.
+
+    Stops after ``limit`` matches when set.
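+
+    Illustrative example (hypothetical values): for two 80 GiB "NVIDIA A100"
+    devices and a request of ``type="a100", memory="40Gi"``, both devices
+    qualify, so this returns ``[0, 1]`` (or ``[0]`` when ``limit=1``).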
+ """ + if limit is not None and limit <= 0: + return [] + type_pattern = gpu_type_pattern(gpu_req.type) + min_memory_bytes = ( + parse_gpu_memory_bytes(gpu_req.memory) if gpu_req.memory else None + ) + result: list[int] = [] + for idx, device in enumerate(devices): + if not gpu_device_matches( + device, + type_pattern=type_pattern, + min_memory_bytes=min_memory_bytes, + ): + continue + result.append(idx) + if limit is not None and len(result) >= limit: + break + return result diff --git a/src/worker/config.py b/src/worker/config.py index f8973d43..9e2461d0 100644 --- a/src/worker/config.py +++ b/src/worker/config.py @@ -11,7 +11,13 @@ from pathlib import Path from typing import Any -from shared.utils.parsing import parse_bool_env, parse_float_env, parse_int_env +from shared.schemas.worker import SSHLimits +from shared.utils.parsing import ( + parse_bool_env, + parse_float_env, + parse_int_env, + parse_mem_to_bytes, +) from .utils.health import get_hb_config @@ -37,6 +43,8 @@ class WorkerConfig: executor_idle_cleanup_sec: float | None enable_mp_executors: bool docker_gpu_runtime: str | None + ssh_limits: SSHLimits | None + enable_ssh_gpu_limit: bool grpc_keepalive_time_ms: int | None = None grpc_keepalive_timeout_ms: int | None = None network_mode: str | None = None @@ -110,6 +118,36 @@ def from_env() -> "WorkerConfig": "WORKER_EXECUTOR_IDLE_CLEANUP_SEC", 60 ) + ssh_max_cpu = parse_float_env("SSH_MAX_CPU") + if ssh_max_cpu is not None and ssh_max_cpu <= 0: + raise SystemExit("SSH_MAX_CPU must be positive") + ssh_max_memory_raw = os.getenv("SSH_MAX_MEMORY", "").strip() or None + ssh_max_memory_bytes: int | None = None + if ssh_max_memory_raw is not None: + ssh_max_memory_bytes = parse_mem_to_bytes(ssh_max_memory_raw) + if ssh_max_memory_bytes is None or ssh_max_memory_bytes <= 0: + raise SystemExit( + f"SSH_MAX_MEMORY value {ssh_max_memory_raw!r} is not a valid " + "memory string (e.g. 
'8Gi', '512Mi', or a positive byte count)" + ) + ssh_max_pids = parse_int_env("SSH_MAX_PIDS") + if ssh_max_pids is not None and ssh_max_pids <= 0: + raise SystemExit("SSH_MAX_PIDS must be positive") + ssh_limits = ( + None + if ( + ssh_max_cpu is None + and ssh_max_memory_bytes is None + and ssh_max_pids is None + ) + else SSHLimits( + max_cpu_cores=ssh_max_cpu, + max_memory_bytes=ssh_max_memory_bytes, + max_pids=ssh_max_pids, + ) + ) + enable_ssh_gpu_limit = parse_bool_env("ENABLE_SSH_GPU_LIMIT", False) + return WorkerConfig( worker_token=worker_token, owner_principal=owner_principal, @@ -130,6 +168,8 @@ def from_env() -> "WorkerConfig": executor_idle_cleanup_sec=executor_idle_cleanup_sec, enable_mp_executors=enable_mp_executors, docker_gpu_runtime=docker_gpu_runtime, + ssh_limits=ssh_limits, + enable_ssh_gpu_limit=enable_ssh_gpu_limit, grpc_keepalive_time_ms=grpc_keepalive_time_ms, grpc_keepalive_timeout_ms=grpc_keepalive_timeout_ms, network_mode=network_mode, diff --git a/src/worker/executors/agent_executor.py b/src/worker/executors/agent_executor.py index 4768d18a..c12d4003 100644 --- a/src/worker/executors/agent_executor.py +++ b/src/worker/executors/agent_executor.py @@ -17,8 +17,6 @@ from datasets import load_dataset from shared.tasks.specs import AgentSpecStrict -from worker.config import WorkerConfig -from worker.lifecycle import Lifecycle from .base_executor import ExecutionError, Executor, ExecutorTask from .utils.checkpoints import ( @@ -61,10 +59,8 @@ class AgentExecutor(Executor): name = "agent" - def __init__( - self, config: WorkerConfig, lifecycle: Lifecycle | None = None - ) -> None: - super().__init__(config, lifecycle) + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) self._initialized = False self._tasks: list[str] = [] diff --git a/src/worker/executors/base_executor.py b/src/worker/executors/base_executor.py index e2ba0134..f4f43f90 100644 --- a/src/worker/executors/base_executor.py +++ b/src/worker/executors/base_executor.py @@ -29,7 +29,7 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict: from shared.tasks import MergedChildTaskStrict from shared.tasks.specs import TaskSpecStrictBase -from shared.tasks.worker_message import WorkerTaskMessage +from shared.tasks.worker_message import WorkerHardware, WorkerTaskMessage from worker.config import WorkerConfig from worker.lifecycle import Lifecycle @@ -57,10 +57,14 @@ class Executor(ABC): name: str = "executor" def __init__( - self, config: WorkerConfig, lifecycle: Lifecycle | None = None + self, + config: WorkerConfig, + hardware: WorkerHardware | None = None, + lifecycle: Lifecycle | None = None, ) -> None: super().__init__() self._config = config + self._hardware = hardware self._lifecycle = lifecycle def emit_update(self, task_id: str, payload: dict[str, Any]) -> None: diff --git a/src/worker/executors/diffusers_executor.py b/src/worker/executors/diffusers_executor.py index 1216acb7..f74fb8e4 100644 --- a/src/worker/executors/diffusers_executor.py +++ b/src/worker/executors/diffusers_executor.py @@ -15,8 +15,6 @@ from PIL import Image from shared.tasks.specs import DiffusionSpecStrict -from worker.config import WorkerConfig -from worker.lifecycle import Lifecycle from ..utils.logging import configure_hf_library_logging from .base_executor import ExecutionError, Executor, ExecutorTask @@ -52,10 +50,8 @@ class DiffusersExecutor(DataMixin, Executor): name = "diffusers" - def __init__( - self, config: WorkerConfig, lifecycle: Lifecycle | None = None - ) -> None: - 
super().__init__(config, lifecycle) + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) self._pipe: Any | None = None self._device: str | None = None self._model_name: str | None = None diff --git a/src/worker/executors/dpo_executor.py b/src/worker/executors/dpo_executor.py index 2054b77f..7dc1e8d2 100644 --- a/src/worker/executors/dpo_executor.py +++ b/src/worker/executors/dpo_executor.py @@ -27,8 +27,6 @@ from shared.tasks.specs import DPOSpecStrict from shared.utils.manifest import scratch_dir -from worker.config import WorkerConfig -from worker.lifecycle import Lifecycle from ..utils.logging import configure_hf_library_logging from .base_executor import ExecutionError, Executor, ExecutorTask @@ -52,10 +50,8 @@ class DPOExecutor(TrainingMixin, Executor): name = "dpo_executor" - def __init__( - self, config: WorkerConfig, lifecycle: Lifecycle | None = None - ) -> None: - super().__init__(config, lifecycle) + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) self._model_name: str | None = None self._current_model: PreTrainedModel | None = None self._current_ref_model: PreTrainedModel | None = None diff --git a/src/worker/executors/lora_sft_executor.py b/src/worker/executors/lora_sft_executor.py index 35d4d0fd..f40c99ad 100644 --- a/src/worker/executors/lora_sft_executor.py +++ b/src/worker/executors/lora_sft_executor.py @@ -19,8 +19,6 @@ from trl.trainer.sft_trainer import SFTTrainer from shared.tasks.specs import LoRASFTSpecStrict -from worker.config import WorkerConfig -from worker.lifecycle import Lifecycle from ..utils.logging import configure_hf_library_logging from .base_executor import ExecutionError, Executor, ExecutorTask @@ -55,10 +53,8 @@ class LoRASFTExecutor(TrainingMixin, Executor): name = "lora_sft_executor" - def __init__( - self, config: WorkerConfig, lifecycle: Lifecycle | None = None - ) -> None: - super().__init__(config, lifecycle) + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) self._model_name: str | None = None self._current_model: Any | None = None self._current_trainer: Any | None = None diff --git a/src/worker/executors/mp_executor.py b/src/worker/executors/mp_executor.py index 03b709be..f2379fe2 100644 --- a/src/worker/executors/mp_executor.py +++ b/src/worker/executors/mp_executor.py @@ -21,6 +21,7 @@ import psutil +from shared.tasks.worker_message import WorkerHardware from worker.config import WorkerConfig from .base_executor import ExecutionError, Executor, ExecutorTask @@ -232,6 +233,7 @@ def _configure_worker_logging(log_queue: Queue | None) -> None: def _executor_worker( executor_cls: type[Executor], config: WorkerConfig, + hardware: WorkerHardware | None, cmd_queue: mp.Queue, result_queue: mp.Queue, log_queue: Queue | None, @@ -256,7 +258,7 @@ def _executor_worker( with MPLogHandler(enabled=log_queue is not None): executor: Executor | None = None try: - executor = executor_cls(config) + executor = executor_cls(config, hardware) except Exception: logger.warning( "Failed to initialize executor in worker process", exc_info=True @@ -362,10 +364,10 @@ def __init__( self, executor_cls: type[Executor], config: WorkerConfig, - *, + hardware: WorkerHardware | None = None, start_method: str = "spawn", ) -> None: - super().__init__(config) + super().__init__(config, hardware) self._executor_cls = executor_cls self._ctx = mp.get_context(start_method) inner_name = getattr(executor_cls, "name", executor_cls.__name__) @@ -398,6 +400,7 @@ def 
_start_process(self) -> None: args=( self._executor_cls, self._config, + self._hardware, self._cmd_q, self._res_q, self._log_q, diff --git a/src/worker/executors/omni_executor_base.py b/src/worker/executors/omni_executor_base.py index ff2e04dc..87e36f93 100644 --- a/src/worker/executors/omni_executor_base.py +++ b/src/worker/executors/omni_executor_base.py @@ -22,8 +22,6 @@ from shared.tasks.specs import TaskSpecStrictBase from shared.utils.parsing import to_bool, to_int -from ..config import WorkerConfig -from ..lifecycle import Lifecycle from .base_executor import ExecutionError, Executor, ExecutorTask from .mixins.inference import InferenceMixin from .utils.checkpoints import maybe_upload_artifacts, maybe_upload_traces @@ -53,10 +51,8 @@ class OmniExecutorBase(InferenceMixin, Executor): _TASK_SPEC_TYPE: ClassVar[type[TaskSpecStrictBase]] - def __init__( - self, config: WorkerConfig, lifecycle: Lifecycle | None = None - ) -> None: - super().__init__(config, lifecycle) + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) self._omni: Any | None = None self._model_name: str | None = None self._omni_spec: tuple[Any, ...] | None = None diff --git a/src/worker/executors/ppo_executor.py b/src/worker/executors/ppo_executor.py index 9cc3f056..1cc47183 100644 --- a/src/worker/executors/ppo_executor.py +++ b/src/worker/executors/ppo_executor.py @@ -36,8 +36,6 @@ from shared.tasks.specs import PPOSpecStrict from shared.utils.manifest import scratch_dir from shared.utils.parsing import safe_float, safe_int, to_bool -from worker.config import WorkerConfig -from worker.lifecycle import Lifecycle from ..utils.logging import configure_hf_library_logging from .base_executor import ExecutionError, Executor, ExecutorTask @@ -384,10 +382,8 @@ class PPOExecutor(TrainingMixin, Executor): name = "ppo_executor" - def __init__( - self, config: WorkerConfig, lifecycle: Lifecycle | None = None - ) -> None: - super().__init__(config, lifecycle) + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) self._model_name: str | None = None self._policy_model: PreTrainedModel | None = None self._ref_model: PreTrainedModel | None = None diff --git a/src/worker/executors/sft_executor.py b/src/worker/executors/sft_executor.py index 3c9b23fa..4361c94d 100644 --- a/src/worker/executors/sft_executor.py +++ b/src/worker/executors/sft_executor.py @@ -24,8 +24,6 @@ from shared.tasks.specs import SFTSpecStrict, TaskSpecStrictBase from shared.utils.manifest import scratch_dir -from worker.config import WorkerConfig -from worker.lifecycle import Lifecycle from ..utils.logging import configure_hf_library_logging from .base_executor import ExecutionError, Executor, ExecutorTask @@ -48,10 +46,8 @@ class SFTExecutor(TrainingMixin, Executor): name = "sft_executor" - def __init__( - self, config: WorkerConfig, lifecycle: Lifecycle | None = None - ) -> None: - super().__init__(config, lifecycle) + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) self._model_name: str | None = None self._current_model: Any | None = None self._current_trainer: Any | None = None diff --git a/src/worker/executors/ssh_executor.py b/src/worker/executors/ssh_executor.py index 9837a446..ef6d7958 100644 --- a/src/worker/executors/ssh_executor.py +++ b/src/worker/executors/ssh_executor.py @@ -29,18 +29,24 @@ import requests +from shared.tasks.components.resources import GPURequirements from shared.tasks.specs.ssh import ( SSHInputSpec, SSHMountSpec, 
SSHOutputSpec, SSHSpecStrict, ) -from shared.utils import new_ssh_session_id, parse_float_env +from shared.tasks.worker_message import WorkerHardware +from shared.utils import new_ssh_session_id, parse_float_env, parse_mem_to_bytes +from shared.utils.hardware import ( + parse_gpu_memory_bytes, + select_matching_gpu_indices, + unified_gpu_memory_satisfies, +) from shared.utils.http import auth_headers from shared.utils.manifest import ARTIFACTS_DIR, prepare_output_dir from worker.config import WorkerConfig from worker.executors.utils.checkpoints import maybe_upload_artifacts -from worker.lifecycle import Lifecycle from .base_executor import ( ExecutionError, @@ -155,9 +161,18 @@ class SSHConfig: mounts: list[SSHMountSpec] poll_interval_sec: float stop_timeout_sec: float + cpu_limit: float | None + memory_limit_bytes: int | None + pids_limit: int | None + gpu_device_ids: list[str] @classmethod - def from_spec(cls, spec: SSHSpecStrict) -> "SSHConfig": + def from_spec( + cls, + spec: SSHSpecStrict, + worker_cfg: WorkerConfig, + hardware: WorkerHardware | None = None, + ) -> "SSHConfig": """Build a resolved config from a task spec, env vars, and defaults.""" has_gpu = bool(os.getenv("WORKER_HOST_GPU_ID", "").strip()) fallback_image = _DEFAULT_IMAGE_GPU if has_gpu else _DEFAULT_IMAGE_CPU @@ -173,6 +188,10 @@ def from_spec(cls, spec: SSHSpecStrict) -> "SSHConfig": if (ssh_output := spec.sshOutput) else None ) + cpu_limit, memory_limit_bytes, pids_limit = _resolve_resource_limits( + spec, worker_cfg + ) + gpu_device_ids = _resolve_gpu_devices(spec, worker_cfg, hardware) return cls( image=spec.image or default_image, interactive=bool(spec.interactive), @@ -189,7 +208,165 @@ def from_spec(cls, spec: SSHSpecStrict) -> "SSHConfig": mounts=list(spec.mounts or []), poll_interval_sec=poll_interval_sec, stop_timeout_sec=stop_timeout_sec, + cpu_limit=cpu_limit, + memory_limit_bytes=memory_limit_bytes, + pids_limit=pids_limit, + gpu_device_ids=gpu_device_ids, + ) + + +def _resolve_resource_limits( + spec: SSHSpecStrict, worker_cfg: WorkerConfig +) -> tuple[float | None, int | None, int | None]: + """Resolve effective CPU/memory limits as min(task spec, worker cap). + + Returns ``(cpu_limit, memory_limit_bytes, pids_limit)``. Each of them may be + ``None`` to mean unbounded — that is, neither the spec nor the cap constrains it. + """ + spec_cpu: float | None = None + spec_mem_bytes: int | None = None + if (res := spec.resources) and (hw := res.hardware): + if hw.cpu is not None: + spec_cpu = float(hw.cpu) + if hw.memory is not None: + if isinstance(hw.memory, str): + spec_mem_bytes = parse_mem_to_bytes(hw.memory) + if spec_mem_bytes is None: + raise ExecutionError( + f"resources.hardware.memory value {hw.memory!r} is not " + "a valid memory string (e.g. 
'8Gi', '512Mi')" + ) + else: + spec_mem_bytes = int(hw.memory) + + ssh_limits = worker_cfg.ssh_limits + if ssh_limits is None: + return spec_cpu, spec_mem_bytes, None + + cpu_limit = _min_or_none(spec_cpu, ssh_limits.max_cpu_cores) + if ( + spec_cpu is not None + and ssh_limits.max_cpu_cores is not None + and spec_cpu > ssh_limits.max_cpu_cores + ): + logger.warning( + "SSH task requested cpu=%s but worker cap is %s; clamping to cap", + spec_cpu, + ssh_limits.max_cpu_cores, + ) + + memory_limit_bytes = _min_or_none(spec_mem_bytes, ssh_limits.max_memory_bytes) + if ( + spec_mem_bytes is not None + and ssh_limits.max_memory_bytes is not None + and spec_mem_bytes > ssh_limits.max_memory_bytes + ): + logger.warning( + "SSH task requested memory=%d bytes but worker cap is %d; " + "clamping to cap", + spec_mem_bytes, + ssh_limits.max_memory_bytes, + ) + + return cpu_limit, memory_limit_bytes, ssh_limits.max_pids + + +def _min_or_none[T: (int, float)](a: T | None, b: T | None) -> T | None: + if a is None: + return b + if b is None: + return a + return min(a, b) + + +def _resolve_gpu_devices( + spec: SSHSpecStrict, config: WorkerConfig, hardware: WorkerHardware | None +) -> list[str]: + """Pick the smallest subset of the worker's GPUs that satisfies the spec. + + Returns the *host* device IDs to expose to the SSH container. When the spec + sets no GPU constraints at all, returns the worker's full host GPU set; + when only ``type`` or ``memory`` is set without ``count``, defaults to + slicing a single matching device. + """ + host_gpu_ids = [ + d_stripped + for d in os.getenv("WORKER_HOST_GPU_ID", "").split(",") + if (d_stripped := d.strip()) + ] + if not config.enable_ssh_gpu_limit: + return host_gpu_ids + + gpu_req: GPURequirements | None = None + if (res := spec.resources) and (hw := res.hardware): + gpu_req = hw.gpu + if gpu_req is None or ( + gpu_req.count is None and not gpu_req.type and not gpu_req.memory + ): + return host_gpu_ids + + requested = gpu_req.count if gpu_req.count is not None else 1 + if requested <= 0: + return [] + + if not host_gpu_ids: + raise ExecutionError( + f"SSH task requested {requested} GPU(s) but this worker has none" + ) + + # The supervisor passes WORKER_HOST_GPU_ID in the same order as + # worker.hardware.gpu.devices, so positions line up 1:1. When metadata is + # missing or misaligned, fall back to count-only slicing. + devices = hardware.gpu.devices if hardware is not None else [] + if devices and len(devices) != len(host_gpu_ids): + logger.warning( + "WORKER_HOST_GPU_ID (%d) and worker hardware.gpu.devices (%d) " + "disagree; falling back to count-only slicing", + len(host_gpu_ids), + len(devices), ) + devices = [] + + required_mem_bytes: int | None = None + if gpu_req.memory: + required_mem_bytes = parse_gpu_memory_bytes(gpu_req.memory) + if required_mem_bytes is None: + raise ExecutionError( + f"resources.hardware.gpu.memory value {gpu_req.memory!r} is " + "not a valid memory string (e.g. '40Gi', '80GB')" + ) + + if not devices: + # No per-device metadata to filter by — fall back to first-N host IDs. 
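+        # The slice is positional, so only the requested count can be enforced
+        # here; type/memory constraints need per-device metadata to check.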
+        if len(host_gpu_ids) < requested:
+            raise ExecutionError(
+                f"SSH task requested {requested} GPU(s) but only "
+                f"{len(host_gpu_ids)} are available on this worker"
+            )
+        return host_gpu_ids[:requested]
+
+    matching_indices = select_matching_gpu_indices(devices, gpu_req, limit=requested)
+    if len(matching_indices) >= requested:
+        return [host_gpu_ids[idx] for idx in matching_indices]
+
+    # Unified memory fallback
+    if required_mem_bytes is not None and hardware is not None:
+        type_only_req = GPURequirements(
+            count=gpu_req.count, type=gpu_req.type, memory=None
+        )
+        type_matching = select_matching_gpu_indices(
+            devices, type_only_req, limit=requested
+        )
+        if len(type_matching) >= requested and unified_gpu_memory_satisfies(
+            hardware, required_mem_bytes, requested
+        ):
+            return [host_gpu_ids[idx] for idx in type_matching]
+
+    raise ExecutionError(
+        f"SSH task requested {requested} GPU(s) matching the spec but "
+        f"only {len(matching_indices)} satisfying device(s) are available "
+        "on this worker"
+    )
 
 
 class SSHExecutor(Executor):
@@ -197,10 +374,10 @@ class SSHExecutor(Executor):
 
     name = "ssh"
 
-    def __init__(
-        self, config: WorkerConfig, lifecycle: Lifecycle | None = None
-    ) -> None:
-        super().__init__(config, lifecycle)
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        config = self._config
+        lifecycle = self._lifecycle
         self._worker_name = (
             config.container_name
             or config.alias
@@ -250,7 +427,7 @@ def teardown(self) -> None:
 
     def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
         spec = self.require_spec(task, SSHSpecStrict)
-        cfg = SSHConfig.from_spec(spec)
+        cfg = SSHConfig.from_spec(spec, self._config, self._hardware)
 
         access_mode = cfg.access_mode
         interactive = cfg.interactive
@@ -291,9 +468,10 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
             mount_plan.staged_input_specs,
             mount_plan.create_dirs,
             interactive,
+            cfg.gpu_device_ids,
         )
         kwargs = self._build_run_kwargs(
-            cfg.image,
+            cfg,
             container_name,
             environment,
             labels,
@@ -476,6 +654,7 @@ def _build_environment(
         staged_input_specs: list[tuple[str, str]],
         create_dirs: list[str],
         interactive: bool,
+        gpu_device_ids: list[str] | None = None,
     ) -> dict[str, str]:
         env: dict[str, str] = {}
         if interactive:
@@ -484,9 +663,12 @@ def _build_environment(
             env["AUTHORIZED_KEYS"] = "\n".join(authorized_keys)
             env["SSH_UID"] = str(os.getuid())
             env["SSH_GID"] = str(os.getgid())
-        cuda = os.getenv("CUDA_VISIBLE_DEVICES")
-        if cuda is not None:
-            env["CUDA_VISIBLE_DEVICES"] = cuda
+        if gpu_device_ids:
+            # Docker exposes only the sliced devices, which appear as 0..N-1
+            # inside the container regardless of their host IDs.
+            env["CUDA_VISIBLE_DEVICES"] = ",".join(
+                str(i) for i in range(len(gpu_device_ids))
+            )
         if staged_input_specs:
             env["FLOWMESH_STAGED_INPUT_SPECS"] = "\n".join(
                 f"{mount_path}\t{target_path}"
@@ -501,7 +683,7 @@ def _build_environment(
 
     def _build_run_kwargs(
         self,
-        image: str,
+        cfg: SSHConfig,
         container_name: str,
         environment: dict[str, str],
         labels: dict[str, str],
@@ -511,7 +693,7 @@ def _build_run_kwargs(
         interactive: bool,
     ) -> dict[str, Any]:
         kwargs: dict[str, Any] = {
-            "image": image,
+            "image": cfg.image,
             "name": container_name,
             "environment": environment,
             "labels": labels,
@@ -525,16 +707,18 @@ def _build_run_kwargs(
             kwargs["entrypoint"] = [_SSH_RUN_ENTRYPOINT_PATH]
         if command:
             kwargs["command"] = command
-        # WORKER_HOST_GPU_ID holds the real host device IDs assigned by the
-        # server (e.g. "2,3").
-        host_gpu_ids = os.getenv("WORKER_HOST_GPU_ID", "").strip()
-        if host_gpu_ids:
-            device_ids = [
-                d_stripped for d in host_gpu_ids.split(",") if (d_stripped := d.strip())
-            ]
+        if cfg.cpu_limit is not None:
+            kwargs["nano_cpus"] = int(cfg.cpu_limit * 1_000_000_000)
+        if cfg.memory_limit_bytes is not None:
+            kwargs["mem_limit"] = cfg.memory_limit_bytes
+        if cfg.pids_limit is not None:
+            kwargs["pids_limit"] = cfg.pids_limit
+        if cfg.gpu_device_ids:
             try:
                 kwargs["device_requests"] = [
-                    DeviceRequest(device_ids=device_ids, capabilities=[["gpu"]])
+                    DeviceRequest(
+                        device_ids=list(cfg.gpu_device_ids), capabilities=[["gpu"]]
+                    )
                 ]
                 if runtime := self._docker_gpu_runtime:
                     kwargs["runtime"] = runtime
diff --git a/src/worker/executors/transformers_executor.py b/src/worker/executors/transformers_executor.py
index c968c5cd..c4a3b586 100644
--- a/src/worker/executors/transformers_executor.py
+++ b/src/worker/executors/transformers_executor.py
@@ -61,8 +61,6 @@
     EmbeddingSpecStrict,
     InferenceSpecStrict,
 )
-from worker.config import WorkerConfig
-from worker.lifecycle import Lifecycle
 
 from ..utils.logging import configure_hf_library_logging
 from .base_executor import ExecutionError, Executor, ExecutorTask
@@ -122,10 +120,8 @@ class HFTransformersExecutor(InferenceMixin, Executor):
 
     name = "transformers"
 
-    def __init__(
-        self, config: WorkerConfig, lifecycle: Lifecycle | None = None
-    ) -> None:
-        super().__init__(config, lifecycle)
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
         self._tok: PreTrainedTokenizerBase | None = None
         self._image_processor: Any | None = None
         self._model: PreTrainedModel | None = None
diff --git a/src/worker/executors/vllm_executor.py b/src/worker/executors/vllm_executor.py
index 3d5c3d7f..6ad6abd1 100644
--- a/src/worker/executors/vllm_executor.py
+++ b/src/worker/executors/vllm_executor.py
@@ -68,8 +68,6 @@
 
 from shared.schemas.governance import SpanType
 from shared.tasks.specs import InferenceSpecStrict
-from worker.config import WorkerConfig
-from worker.lifecycle import Lifecycle
 
 from .base_executor import ExecutionError, Executor, ExecutorTask
 from .mixins.data import InferenceEntry
@@ -122,10 +120,8 @@ class VLLMExecutor(InferenceMixin, Executor):
 
     Summary:"""
 
-    def __init__(
-        self, config: WorkerConfig, lifecycle: Lifecycle | None = None
-    ) -> None:
-        super().__init__(config, lifecycle)
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
         self._llm: LLM | None = None
         self._model_name: str | None = None
         self._batched_inputs: list[str | TextPrompt] = []
diff --git a/src/worker/executors/vllm_lora_executor.py b/src/worker/executors/vllm_lora_executor.py
index 068e87e4..32fcd6a8 100644
--- a/src/worker/executors/vllm_lora_executor.py
+++ b/src/worker/executors/vllm_lora_executor.py
@@ -22,8 +22,6 @@
 from shared.tasks.components import AdapterConfig
 from shared.tasks.components.model import AdapterApplyMode
 from shared.tasks.specs import InferenceSpecStrict
-from worker.config import WorkerConfig
-from worker.lifecycle import Lifecycle
 
 from .base_executor import ExecutionError
 from .utils.checkpoints import (
@@ -53,10 +51,8 @@ class VLLMLoRAExecutor(VLLMExecutor):
 
     name = "vllm_lora"
 
-    def __init__(
-        self, config: WorkerConfig, lifecycle: Lifecycle | None = None
-    ) -> None:
-        super().__init__(config, lifecycle)
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
         self._adapter_specs: list[LoRAAdapterSpec] = []
         self._runtime_specs: list[LoRAAdapterSpec] = []
diff --git a/src/worker/lifecycle.py b/src/worker/lifecycle.py
index 6aa8332c..34f03555 100644
--- a/src/worker/lifecycle.py
+++ b/src/worker/lifecycle.py
@@ -11,6 +11,7 @@
 from pathlib import Path
 from typing import Any
 
+from shared.schemas.worker import SSHLimits
 from shared.tasks.worker_message import WorkerHardware, WorkerStatus
 from shared.utils.time import now_iso
 
@@ -70,7 +71,13 @@ def _metrics(self) -> dict[str, Any]:
             metrics["estimated_energy_kwh"] = energy_total
         return metrics
 
-    def start(self, env: dict[str, Any], hardware: WorkerHardware, tags: list[str]):
+    def start(
+        self,
+        env: dict[str, Any],
+        hardware: WorkerHardware,
+        ssh_limits: SSHLimits | None,
+        tags: list[str],
+    ):
         self._started_ts = time.time()
         try:
             initial_power = self.power_monitor.sample()
@@ -82,6 +89,7 @@ def start(
             pid=os.getpid(),
             env=env,
             hardware=hardware,
+            ssh_limits=ssh_limits,
             tags=tags,
             cost_per_hour=self.cost_per_hour,
             power_metrics=initial_power,
diff --git a/src/worker/main.py b/src/worker/main.py
index 0b4180b1..8c683431 100644
--- a/src/worker/main.py
+++ b/src/worker/main.py
@@ -2,6 +2,8 @@
 import logging
 import signal
 
+from shared.tasks.worker_message import WorkerHardware
+
 from .config import WorkerConfig
 from .executors import EXECUTOR_CLASS_NAMES, EXECUTOR_REGISTRY, IMPORT_ERRORS
 from .executors.base_executor import Executor
@@ -51,6 +53,7 @@ def _parse_args() -> argparse.Namespace:
 
 def initialize_executors(
     config: WorkerConfig,
+    hardware: WorkerHardware,
     logger: logging.Logger,
     lifecycle: Lifecycle,
     registry: dict[str, type | None] | None = None,
@@ -93,8 +96,8 @@ def init_executor(key: str, *, gpu_required: bool = False):
 
         try:
             if key in configured_wrapped:
-                return MPExecutor(cls, config=config)
-            return cls(config, lifecycle)
+                return MPExecutor(cls, config, hardware)
+            return cls(config, hardware, lifecycle)
         except Exception as exc:
             logger.warning("Failed to initialize executor %s: %s", key, exc)
             return None
@@ -190,10 +193,19 @@ def main() -> None:
     )
     hardware = collect_hw(bandwidth_bytes_per_sec=cfg.network_bandwidth_bytes_per_sec)
     logger.info("Collected hardware info: %s", hardware)
-    lifecycle.start(env={}, hardware=hardware, tags=cfg.tags)
+    ssh_limits = cfg.ssh_limits
+    if ssh_limits is None:
+        logger.warning(
+            "SSH resource cap not configured; SSH sessions will be able to access "
+            "the full host resources of this worker."
+        )
+    else:
+        logger.info("SSH resource cap: %s", ssh_limits.model_dump())
+    lifecycle.start(env={}, hardware=hardware, ssh_limits=ssh_limits, tags=cfg.tags)
 
     executors, default_executor = initialize_executors(
         cfg,
+        hardware,
         logger,
         lifecycle,
         enable_mp_executors=cfg.enable_mp_executors,
diff --git a/src/worker/supervisor_client.py b/src/worker/supervisor_client.py
index 876ecff1..5c3889b7 100644
--- a/src/worker/supervisor_client.py
+++ b/src/worker/supervisor_client.py
@@ -16,6 +16,7 @@
 
 from shared.grpc.supervisor.v1 import supervisor_pb2, supervisor_pb2_grpc
 from shared.schemas.event import Event, TaskEvent, WorkerEvent, serialize_event
+from shared.schemas.worker import SSHLimits
 from shared.tasks.worker_message import (
     WorkerHardware,
     WorkerStatus,
@@ -93,6 +94,7 @@ def register(
         pid: int,
         env: dict[str, Any],
         hardware: WorkerHardware,
+        ssh_limits: SSHLimits | None,
         tags: list[str],
         cost_per_hour: float,
         power_metrics: dict[str, Any] | None = None,
@@ -117,6 +119,8 @@ def register(
             "last_seen": started_at,
             "cost_per_hour": str(cost_per_hour),
         }
+        if ssh_limits is not None:
+            worker_meta["ssh_limits_json"] = ssh_limits.model_dump_json()
 
         self._register_grpc(worker_meta)
         self.logger.info("Worker connected via supervisor at %s", self.grpc_target)
diff --git a/tests/sdk/test_env_schema.py b/tests/sdk/test_env_schema.py
index 58a20991..109db432 100644
--- a/tests/sdk/test_env_schema.py
+++ b/tests/sdk/test_env_schema.py
@@ -1,8 +1,11 @@
+import pytest
 from flowmesh_stack.env_schema import (
     EnvSchema,
     EnvSection,
     EnvVar,
+    EnvVarType,
     render_env_example,
+    validate_env_values,
 )
 
 
@@ -40,3 +43,124 @@ def test_render_env_example_ignores_overrides_for_unknown_keys() -> None:
     body = render_env_example(_toy_schema(), overrides={"NOT_A_KEY": "x"})
     assert "NOT_A_KEY" not in body
     assert "NODE_ROLE=root" in body
+
+
+def _range_schema(var: EnvVar) -> EnvSchema:
+    return EnvSchema(
+        name="bounds", header=[], sections=[EnvSection(title="bounds", vars=[var])]
+    )
+
+
+class TestValidateRangeBounds:
+    @pytest.mark.parametrize(
+        "var,raw,expect_error",
+        [
+            # min_value inclusive (default)
+            (EnvVar("X", var_type=EnvVarType.INT, min_value=1), "1", False),
+            (EnvVar("X", var_type=EnvVarType.INT, min_value=1), "0", True),
+            # min_value exclusive
+            (
+                EnvVar(
+                    "X",
+                    var_type=EnvVarType.INT,
+                    min_value=0,
+                    min_inclusive=False,
+                ),
+                "0",
+                True,
+            ),
+            (
+                EnvVar(
+                    "X",
+                    var_type=EnvVarType.INT,
+                    min_value=0,
+                    min_inclusive=False,
+                ),
+                "1",
+                False,
+            ),
+            # max_value inclusive (default)
+            (EnvVar("X", var_type=EnvVarType.INT, max_value=10), "10", False),
+            (EnvVar("X", var_type=EnvVarType.INT, max_value=10), "11", True),
+            # max_value exclusive
+            (
+                EnvVar(
+                    "X",
+                    var_type=EnvVarType.INT,
+                    max_value=10,
+                    max_inclusive=False,
+                ),
+                "10",
+                True,
+            ),
+            (
+                EnvVar(
+                    "X",
+                    var_type=EnvVarType.INT,
+                    max_value=10,
+                    max_inclusive=False,
+                ),
+                "9",
+                False,
+            ),
+            # FLOAT path mirrors INT
+            (
+                EnvVar(
+                    "X",
+                    var_type=EnvVarType.FLOAT,
+                    min_value=0,
+                    min_inclusive=False,
+                ),
+                "0",
+                True,
+            ),
+            (
+                EnvVar(
+                    "X",
+                    var_type=EnvVarType.FLOAT,
+                    min_value=0,
+                    min_inclusive=False,
+                ),
+                "0.5",
+                False,
+            ),
+            (
+                EnvVar(
+                    "X",
+                    var_type=EnvVarType.FLOAT,
+                    max_value=1.5,
+                    max_inclusive=True,
+                ),
+                "1.5",
+                False,
+            ),
+            (
+                EnvVar(
+                    "X",
+                    var_type=EnvVarType.FLOAT,
+                    max_value=1.5,
+                    max_inclusive=False,
+                ),
+                "1.5",
+                True,
+            ),
+        ],
+    )
+    def test_bounds(self, var: EnvVar, raw: str, expect_error: bool) -> None:
+        errors, _ = validate_env_values(_range_schema(var), {var.key: raw})
+        assert bool(errors) is expect_error
+
+    def test_exclusive_error_message_uses_strict_comparator(self) -> None:
+        var = EnvVar(
+            "SSH_MAX_CPU",
+            var_type=EnvVarType.FLOAT,
+            min_value=0,
+            min_inclusive=False,
+        )
+        errors, _ = validate_env_values(_range_schema(var), {"SSH_MAX_CPU": "0"})
+        assert errors == ["SSH_MAX_CPU must be > 0"]
+
+    def test_inclusive_error_message_uses_default_comparator(self) -> None:
+        var = EnvVar("PORT", var_type=EnvVarType.INT, min_value=1024)
+        errors, _ = validate_env_values(_range_schema(var), {"PORT": "80"})
+        assert errors == ["PORT must be >= 1024"]
diff --git a/tests/sdk/test_schema_compat.py b/tests/sdk/test_schema_compat.py
index c600b0a8..b757060f 100644
--- a/tests/sdk/test_schema_compat.py
+++ b/tests/sdk/test_schema_compat.py
@@ -25,6 +25,7 @@
     NodeWorkerInfo,
     OkResponse,
     ResultEnvelope,
+    SSHLimits,
     StorageInfo,
     TaskInfo,
     TaskType,
@@ -79,6 +80,7 @@
 from server.task.models import TaskInfo as SrvTaskInfo
 from server.task.models import TaskUsage as SrvTaskUsage
 from shared.schemas.result import ResultEnvelope as SrvResultEnvelope
+from shared.schemas.worker import SSHLimits as SrvSSHLimits
 from shared.tasks.task_type import TaskType as SrvTaskType
 
 from .helpers import assert_enum_members_match, assert_fields_match
@@ -113,6 +115,7 @@
     (SrvStorageInfo, StorageInfo),
     (SrvHostInfo, HostInfo),
     (SrvWorkerHardware, WorkerHardware),
+    (SrvSSHLimits, SSHLimits),
     # Logs
     (SrvLogEvent, LogEvent),
     (SrvLogEntry, LogEntry),
diff --git a/tests/server/registries/test_worker_registry.py b/tests/server/registries/test_worker_registry.py
index 88d85d96..b3a8ed62 100644
--- a/tests/server/registries/test_worker_registry.py
+++ b/tests/server/registries/test_worker_registry.py
@@ -1,7 +1,7 @@
 """Tests for worker hardware satisfaction and sorting."""
 
 from server.registries.worker import Worker, hw_satisfies
-from shared.schemas.worker import WorkerStatus
+from shared.schemas.worker import SSHLimits, WorkerStatus
 from shared.tasks import TaskEnvelopeStrict
 from shared.tasks.components.resources import (
     GPURequirements,
@@ -27,6 +27,7 @@ def _worker(
     cpu_cores: int = 4,
     gpu_memory_is_unified: bool = False,
     gpu_shared_memory_total_bytes: int | None = None,
+    ssh_limits: SSHLimits | None = None,
 ) -> Worker:
     devices = [
         GpuInfo(index=i, name=gpu_name, uuid=f"GPU-{i}", memory_total_bytes=gpu_mem)
@@ -52,6 +53,7 @@ def _worker(
         node_alias="g",
         status=WorkerStatus.IDLE,
         hardware=hw,
+        ssh_limits=ssh_limits,
     )
 
 
@@ -161,3 +163,67 @@ def test_cpu_not_satisfied(self) -> None:
         w = _worker(cpu_cores=2)
         t = _task(cpu=8)
         assert hw_satisfies(w, t) is False
+
+
+def _ssh_task(cpu: int | None = None, memory: str | None = None) -> TaskEnvelopeStrict:
+    hw_req = None
+    if cpu is not None or memory is not None:
+        hw_req = HardwareRequirements(cpu=cpu, memory=memory)
+    resources = ResourcesSpec(hardware=hw_req) if hw_req else None
+    return TaskEnvelopeStrict.model_validate(
+        {
+            "apiVersion": "flowmesh/v1",
+            "kind": "Task",
+            "spec": {
+                "taskType": "ssh",
+                "interactive": False,
+                "image": "x",
+                "command": ["true"],
+                "resources": resources.model_dump() if resources else None,
+            },
+        }
+    )
+
+
+class TestHwSatisfiesSSHLimits:
+    def test_ssh_cap_below_request_filters_worker(self) -> None:
+        w = _worker(
+            cpu_cores=32,
+            sys_mem=64 * 1024**3,
+            ssh_limits=SSHLimits(max_cpu_cores=2.0),
+        )
+        t = _ssh_task(cpu=8)
+        assert hw_satisfies(w, t) is False
+
+    def test_ssh_cap_above_request_passes(self) -> None:
+        w = _worker(
+            cpu_cores=32,
+            sys_mem=64 * 1024**3,
+            ssh_limits=SSHLimits(max_cpu_cores=16.0),
+        )
+        t = _ssh_task(cpu=8)
+        assert hw_satisfies(w, t) is True
+
+    def test_ssh_memory_cap_filters(self) -> None:
+        w = _worker(
+            cpu_cores=32,
+            sys_mem=64 * 1024**3,
+            ssh_limits=SSHLimits(max_memory_bytes=2 * 1024**3),
+        )
+        t = _ssh_task(memory="4Gi")
+        assert hw_satisfies(w, t) is False
+
+    def test_ssh_cap_ignored_for_non_ssh_tasks(self) -> None:
+        # Even if ssh_limits would filter out the worker for SSH, non-SSH
+        # tasks should see the full physical hardware.
+        w = _worker(
+            cpu_cores=32,
+            ssh_limits=SSHLimits(max_cpu_cores=2.0),
+        )
+        t = _task(cpu=8)
+        assert hw_satisfies(w, t) is True
+
+    def test_no_ssh_cap_behaves_as_before(self) -> None:
+        w = _worker(cpu_cores=32, sys_mem=64 * 1024**3)
+        t = _ssh_task(cpu=8, memory="4Gi")
+        assert hw_satisfies(w, t) is True
diff --git a/tests/server/test_docker_ssh_limits.py b/tests/server/test_docker_ssh_limits.py
new file mode 100644
index 00000000..de7c8fc5
--- /dev/null
+++ b/tests/server/test_docker_ssh_limits.py
@@ -0,0 +1,42 @@
+"""Tests for the supervisor Docker adapter's SSH resource cap plumbing."""
+
+import pytest
+
+from server.supervisor.adapters.docker import SSHConfig
+
+
+class TestSSHConfigToEnv:
+    def test_omits_unset_limits(self) -> None:
+        env = SSHConfig().to_env()
+        assert "SSH_MAX_CPU" not in env
+        assert "SSH_MAX_MEMORY" not in env
+        assert "SSH_MAX_PIDS" not in env
+
+    def test_emits_set_limits(self) -> None:
+        env = SSHConfig(max_cpu=4.0, max_memory="8Gi", max_pids=512).to_env()
+        assert env["SSH_MAX_CPU"] == "4.0"
+        assert env["SSH_MAX_MEMORY"] == "8Gi"
+        assert env["SSH_MAX_PIDS"] == "512"
+
+
+class TestSSHConfigToLimits:
+    def test_returns_none_when_unset(self) -> None:
+        assert SSHConfig().to_limits() is None
+
+    def test_parses_memory_string(self) -> None:
+        limits = SSHConfig(max_cpu=2.0, max_memory="4Gi", max_pids=128).to_limits()
+        assert limits is not None
+        assert limits.max_cpu_cores == 2.0
+        assert limits.max_memory_bytes == 4 * 1024**3
+        assert limits.max_pids == 128
+
+    def test_invalid_memory_raises(self) -> None:
+        with pytest.raises(ValueError, match="SSH_MAX_MEMORY"):
+            SSHConfig(max_memory="garbage").to_limits()
+
+    def test_partial_caps(self) -> None:
+        limits = SSHConfig(max_cpu=1.5).to_limits()
+        assert limits is not None
+        assert limits.max_cpu_cores == 1.5
+        assert limits.max_memory_bytes is None
+        assert limits.max_pids is None
diff --git a/tests/shared/utils/test_hardware.py b/tests/shared/utils/test_hardware.py
new file mode 100644
index 00000000..06d6fb32
--- /dev/null
+++ b/tests/shared/utils/test_hardware.py
@@ -0,0 +1,259 @@
+"""Tests for the shared GPU requirement helpers."""
+
+import pytest
+
+from shared.tasks.components.resources import GPURequirements
+from shared.tasks.worker_message import (
+    CPUInfo,
+    GpuInfo,
+    GpuPlatformInfo,
+    MemoryInfo,
+    NetworkInfo,
+    WorkerHardware,
+)
+from shared.utils.hardware import (
+    gpu_device_matches,
+    gpu_type_pattern,
+    normalize_gpu_type,
+    parse_gpu_memory_bytes,
+    select_matching_gpu_indices,
+    unified_gpu_memory_satisfies,
+)
+
+
+class TestNormalizeGpuType:
+    @pytest.mark.parametrize(
+        "value,expected",
+        [
+            (None, None),
+            ("", None),
+            (" ", None),
+            ("any", None),
+            ("AUTO", None),
+            ("*", None),
+            ("a100", "a100"),
+            ("A100", "a100"),
+            (" A100 ", "a100"),
+        ],
+    )
+    def test_normalize(self, value: str | None, expected: str | None) -> None:
+        assert normalize_gpu_type(value) == expected
+
+
+class TestGpuTypePattern:
+    def test_returns_none_for_wildcard(self) -> None:
+        assert gpu_type_pattern(None) is None
+        assert gpu_type_pattern("any") is None
+        assert gpu_type_pattern("*") is None
+
+    def test_returns_case_insensitive_substring_pattern(self) -> None:
+        pattern = gpu_type_pattern("A100")
+        assert pattern is not None
+        assert pattern.search("NVIDIA A100-SXM4-80GB") is not None
+        assert pattern.search("nvidia a100") is not None
+        assert pattern.search("NVIDIA T4") is None
+
+    def test_escapes_special_regex_metachars(self) -> None:
+        # User-provided strings must be matched literally, not as regex.
+        pattern = gpu_type_pattern("A100.foo")
+        assert pattern is not None
+        assert pattern.search("NVIDIA A100xfoo") is None
+        assert pattern.search("NVIDIA A100.foo") is not None
+
+
+class TestParseGpuMemoryBytes:
+    @pytest.mark.parametrize(
+        "value,expected",
+        [
+            (None, None),
+            ("40Gi", 40 * 1024**3),
+            ("512Mi", 512 * 1024**2),
+            ("80GB", 80 * 1024**3),
+            (1048576, 1048576),
+            (1024.0, 1024),
+            (0, 0),
+        ],
+    )
+    def test_supported_inputs(
+        self, value: str | int | float | None, expected: int | None
+    ) -> None:
+        assert parse_gpu_memory_bytes(value) == expected
+
+    def test_unparsable_string_returns_none(self) -> None:
+        assert parse_gpu_memory_bytes("garbage") is None
+
+
+class TestGpuDeviceMatches:
+    def _device(
+        self, name: str = "NVIDIA A100-SXM4-80GB", memory_bytes: int = 80 * 1024**3
+    ) -> GpuInfo:
+        return GpuInfo(index=0, name=name, uuid="x", memory_total_bytes=memory_bytes)
+
+    def test_no_constraints_accepts(self) -> None:
+        assert gpu_device_matches(self._device()) is True
+
+    def test_type_match(self) -> None:
+        pattern = gpu_type_pattern("A100")
+        assert gpu_device_matches(self._device(), type_pattern=pattern) is True
+
+    def test_type_mismatch(self) -> None:
+        pattern = gpu_type_pattern("H100")
+        assert gpu_device_matches(self._device(), type_pattern=pattern) is False
+
+    def test_memory_meets_floor(self) -> None:
+        assert (
+            gpu_device_matches(
+                self._device(memory_bytes=80 * 1024**3),
+                min_memory_bytes=40 * 1024**3,
+            )
+            is True
+        )
+
+    def test_memory_below_floor(self) -> None:
+        assert (
+            gpu_device_matches(
+                self._device(memory_bytes=16 * 1024**3),
+                min_memory_bytes=40 * 1024**3,
+            )
+            is False
+        )
+
+    def test_missing_device_memory_treated_as_zero(self) -> None:
+        device = GpuInfo(index=0, name="A100", uuid="x", memory_total_bytes=None)
+        assert gpu_device_matches(device, min_memory_bytes=1) is False
+        # No memory constraint still accepts even when memory_total_bytes is None.
+        assert gpu_device_matches(device) is True
+
+    def test_combined_predicates(self) -> None:
+        pattern = gpu_type_pattern("A100")
+        # Type matches but memory below floor → reject.
+        assert (
+            gpu_device_matches(
+                self._device(memory_bytes=40 * 1024**3),
+                type_pattern=pattern,
+                min_memory_bytes=80 * 1024**3,
+            )
+            is False
+        )
+        # Both satisfied → accept.
+ assert ( + gpu_device_matches( + self._device(memory_bytes=80 * 1024**3), + type_pattern=pattern, + min_memory_bytes=40 * 1024**3, + ) + is True + ) + + +class TestSelectMatchingGpuIndices: + def _devices(self) -> list[GpuInfo]: + return [ + GpuInfo( + index=0, name="NVIDIA T4", uuid="t4-0", memory_total_bytes=16 * 1024**3 + ), + GpuInfo( + index=1, + name="NVIDIA A100-SXM4-40GB", + uuid="a100-0", + memory_total_bytes=40 * 1024**3, + ), + GpuInfo( + index=2, + name="NVIDIA A100-SXM4-80GB", + uuid="a100-1", + memory_total_bytes=80 * 1024**3, + ), + GpuInfo( + index=3, + name="NVIDIA A100-SXM4-80GB", + uuid="a100-2", + memory_total_bytes=80 * 1024**3, + ), + ] + + def test_no_constraints_returns_all_indices(self) -> None: + result = select_matching_gpu_indices(self._devices(), GPURequirements()) + assert result == [0, 1, 2, 3] + + def test_type_filter(self) -> None: + result = select_matching_gpu_indices( + self._devices(), GPURequirements(type="A100") + ) + assert result == [1, 2, 3] + + def test_memory_filter(self) -> None: + result = select_matching_gpu_indices( + self._devices(), GPURequirements(memory="80Gi") + ) + assert result == [2, 3] + + def test_per_device_and_semantics(self) -> None: + # An A100-40GB device matches type but not the 80Gi memory floor; it + # must be excluded. This is what makes the helper consistent across + # the dispatcher and the SSH executor. + result = select_matching_gpu_indices( + self._devices(), GPURequirements(type="A100", memory="80Gi") + ) + assert result == [2, 3] + + def test_limit_stops_early(self) -> None: + result = select_matching_gpu_indices( + self._devices(), GPURequirements(type="A100"), limit=2 + ) + assert result == [1, 2] + + def test_limit_zero_returns_empty(self) -> None: + result = select_matching_gpu_indices( + self._devices(), GPURequirements(), limit=0 + ) + assert result == [] + + def test_empty_devices(self) -> None: + assert select_matching_gpu_indices([], GPURequirements(type="A100")) == [] + + +def _unified_hw( + *, + is_unified: bool, + shared_bytes: int | None, +) -> WorkerHardware: + return WorkerHardware( + cpu=CPUInfo(logical_cores=8, model="x"), + memory=MemoryInfo(total_bytes=128 * 1024**3), + gpu=GpuPlatformInfo( + driver_version=None, + cuda_version=None, + devices=[ + GpuInfo( + index=0, name="NVIDIA GB10", uuid="gb10", memory_total_bytes=None + ) + ], + memory_is_unified=is_unified, + shared_memory_total_bytes=shared_bytes, + ), + network=NetworkInfo(ip=None, bandwidth_bytes_per_sec=None), + ) + + +class TestUnifiedGpuMemorySatisfies: + def test_non_unified_returns_false(self) -> None: + hw = _unified_hw(is_unified=False, shared_bytes=128 * 1024**3) + assert unified_gpu_memory_satisfies(hw, 40 * 1024**3, 1) is False + + def test_no_shared_pool_returns_false(self) -> None: + hw = _unified_hw(is_unified=True, shared_bytes=None) + assert unified_gpu_memory_satisfies(hw, 40 * 1024**3, 1) is False + + def test_pool_covers_single_gpu(self) -> None: + hw = _unified_hw(is_unified=True, shared_bytes=128 * 1024**3) + assert unified_gpu_memory_satisfies(hw, 40 * 1024**3, 1) is True + + def test_per_gpu_share_below_request(self) -> None: + # 128 GiB pool / 4 requested = 32 GiB per slot, below the 40 GiB floor. 
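+        # (Assumed check, consistent with these assertions: the helper
+        # presumably compares shared_memory_total_bytes // requested_count
+        # against the per-GPU memory floor.)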
+ hw = _unified_hw(is_unified=True, shared_bytes=128 * 1024**3) + assert unified_gpu_memory_satisfies(hw, 40 * 1024**3, 4) is False + + def test_per_gpu_share_meets_request(self) -> None: + hw = _unified_hw(is_unified=True, shared_bytes=128 * 1024**3) + assert unified_gpu_memory_satisfies(hw, 32 * 1024**3, 4) is True diff --git a/tests/worker/factories.py b/tests/worker/factories.py index db2c8c49..10a5ef91 100644 --- a/tests/worker/factories.py +++ b/tests/worker/factories.py @@ -7,7 +7,16 @@ from lumid_hooks import PrincipalContext from shared.tasks import TaskType -from shared.tasks.worker_message import TaskEnvelopeStrict, WorkerTaskMessage +from shared.tasks.worker_message import ( + CPUInfo, + GpuInfo, + GpuPlatformInfo, + MemoryInfo, + NetworkInfo, + TaskEnvelopeStrict, + WorkerHardware, + WorkerTaskMessage, +) from worker.config import WorkerConfig DEFAULT_WORKER_CONFIG: Final[WorkerConfig] = WorkerConfig( @@ -36,6 +45,8 @@ executor_idle_cleanup_sec=None, enable_mp_executors=False, docker_gpu_runtime=None, + ssh_limits=None, + enable_ssh_gpu_limit=False, ) @@ -91,3 +102,17 @@ def make_worker_task_message( task=TaskEnvelopeStrict(apiVersion=api_version, kind=kind, spec=spec), **overrides, ) + + +def make_worker_hardware(devices: list[GpuInfo] | None = None) -> WorkerHardware: + """Build a WorkerHardware with sensible test defaults.""" + return WorkerHardware( + cpu=CPUInfo(logical_cores=2, model="x"), + memory=MemoryInfo(total_bytes=1024**3), + gpu=GpuPlatformInfo( + driver_version=None, + cuda_version=None, + devices=devices or [], + ), + network=NetworkInfo(ip=None, bandwidth_bytes_per_sec=None), + ) diff --git a/tests/worker/test_connector_logging.py b/tests/worker/test_connector_logging.py index 49e42506..beda187e 100644 --- a/tests/worker/test_connector_logging.py +++ b/tests/worker/test_connector_logging.py @@ -6,7 +6,7 @@ from pathlib import Path from shared.tasks.worker_message import WorkerTaskMessage -from tests.worker.factories import make_live_worker_config +from tests.worker.factories import make_live_worker_config, make_worker_hardware from worker.executors.base_executor import Executor from worker.executors.mp_executor import MPExecutor @@ -60,7 +60,11 @@ def test_connector_logs_printed_to_stderr(tmp_path: Path) -> None: by checking container logs or log files in production. """ # Create MP executor with test executor - mp = MPExecutor(ConnectorLoggingExecutor, config=make_live_worker_config(tmp_path)) + mp = MPExecutor( + ConnectorLoggingExecutor, + config=make_live_worker_config(tmp_path), + hardware=make_worker_hardware(), + ) task_payload = WorkerTaskMessage.model_validate( { diff --git a/tests/worker/test_executor_bootstrap.py b/tests/worker/test_executor_bootstrap.py new file mode 100644 index 00000000..34f6596c --- /dev/null +++ b/tests/worker/test_executor_bootstrap.py @@ -0,0 +1,49 @@ +"""Regression tests for executor bootstrap in ``worker.main.initialize_executors``. + +Pins the contract that ``initialize_executors`` constructs every non-MP executor +with ``cls(config, hardware, lifecycle)``. Subclasses are expected to accept this +via ``(*args, **kwargs)`` passthrough so future ``Executor.__init__`` extensions +don't break the chain. 
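+
+A minimal sketch of the anti-pattern this guards against (``RigidExecutor``
+is a hypothetical name, not part of this change)::
+
+    class RigidExecutor(Executor):
+        # Pinned signature: cannot accept the extra positional
+        # ``hardware`` argument that initialize_executors now passes.
+        def __init__(self, config) -> None:
+            super().__init__(config)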
+""" + +import logging +from pathlib import Path +from typing import Any + +from tests.worker.factories import make_live_worker_config, make_worker_hardware +from worker.executors.base_executor import Executor, ExecutorTask +from worker.main import initialize_executors + + +class _PassthroughExecutor(Executor): + """Executor that forwards constructor args via the recommended pattern.""" + + name = "passthrough" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: # noqa: D401 + return {"ok": True} + + +class TestInitializeExecutorsHardware: + def test_executor_receives_hardware_via_passthrough(self, tmp_path: Path) -> None: + cfg = make_live_worker_config(tmp_path) + hw = make_worker_hardware() + executors, default = initialize_executors( + config=cfg, + hardware=hw, + logger=logging.getLogger("test"), + lifecycle=None, # type: ignore[arg-type] + registry={"echo": _PassthroughExecutor, "default": _PassthroughExecutor}, + import_errors={}, + cuda_available=False, + enable_mp_executors=False, + ) + # Pre-fix this would silently drop the executor because the subclass + # constructor didn't accept the new positional ``hardware`` arg. + assert isinstance(executors["echo"], _PassthroughExecutor) + assert isinstance(default, _PassthroughExecutor) + assert executors["echo"]._hardware is hw + assert default._hardware is hw diff --git a/tests/worker/test_mp_executor_cleanup_gpu.py b/tests/worker/test_mp_executor_cleanup_gpu.py index e29c444c..0c0f409a 100644 --- a/tests/worker/test_mp_executor_cleanup_gpu.py +++ b/tests/worker/test_mp_executor_cleanup_gpu.py @@ -9,7 +9,7 @@ import pytest from shared.tasks.worker_message import WorkerTaskMessage -from tests.worker.factories import make_live_worker_config +from tests.worker.factories import make_live_worker_config, make_worker_hardware from worker.executors.mp_executor import MPExecutor from worker.executors.vllm_executor import VLLMExecutor @@ -47,7 +47,11 @@ def total_gpu_used() -> int: gpu_before = total_gpu_used() - mp = MPExecutor(VLLMExecutor, config=make_live_worker_config(tmp_path)) + mp = MPExecutor( + VLLMExecutor, + config=make_live_worker_config(tmp_path), + hardware=make_worker_hardware(), + ) # Create task payload matching parse_task_yaml output format task_payload = WorkerTaskMessage.model_validate( diff --git a/tests/worker/test_mp_executor_lifecycle.py b/tests/worker/test_mp_executor_lifecycle.py index 4749ee22..3fe73be8 100644 --- a/tests/worker/test_mp_executor_lifecycle.py +++ b/tests/worker/test_mp_executor_lifecycle.py @@ -7,7 +7,11 @@ from shared.tasks import TaskType from shared.tasks.specs import EchoSpecStrict from shared.tasks.worker_message import WorkerTaskMessage -from tests.worker.factories import make_live_worker_config, make_worker_task_message +from tests.worker.factories import ( + make_live_worker_config, + make_worker_hardware, + make_worker_task_message, +) from worker.executors.base_executor import Executor from worker.executors.mp_executor import MPExecutor @@ -36,7 +40,11 @@ def _simple_task_message() -> WorkerTaskMessage: def test_mp_executor_does_not_start_subprocess_until_first_run(tmp_path: Path) -> None: - mp = MPExecutor(_SimpleMPExecutor, config=make_live_worker_config(tmp_path)) + mp = MPExecutor( + _SimpleMPExecutor, + config=make_live_worker_config(tmp_path), + hardware=make_worker_hardware(), + ) assert mp._shutdown is True assert mp._proc is None @@ -58,7 +66,11 @@ def 
test_mp_executor_does_not_start_subprocess_until_first_run(tmp_path: Path) - def test_mp_executor_cleanup_before_run_is_noop(tmp_path: Path) -> None: - mp = MPExecutor(_SimpleMPExecutor, config=make_live_worker_config(tmp_path)) + mp = MPExecutor( + _SimpleMPExecutor, + config=make_live_worker_config(tmp_path), + hardware=make_worker_hardware(), + ) mp.cleanup_after_run() diff --git a/tests/worker/test_ssh_executor_noninteractive.py b/tests/worker/test_ssh_executor_noninteractive.py index ea44d73c..27852ce2 100644 --- a/tests/worker/test_ssh_executor_noninteractive.py +++ b/tests/worker/test_ssh_executor_noninteractive.py @@ -12,7 +12,7 @@ import worker.executors.ssh_executor as ssh_executor_module from shared.tasks.specs import SSHSpecStrict from shared.tasks.worker_message import WorkerTaskMessage -from tests.worker.factories import make_live_worker_config +from tests.worker.factories import DEFAULT_WORKER_CONFIG, make_live_worker_config from worker.executors.base_executor import ExecutionError from worker.executors.ssh_executor import ( _SSH_RUN_ENTRYPOINT_PATH, @@ -52,7 +52,7 @@ def _task_message(**spec_updates: object) -> WorkerTaskMessage: class TestSSHConfigFromSpec: def test_noninteractive_config(self) -> None: task = _task_message() - cfg = SSHConfig.from_spec(cast(SSHSpecStrict, task.spec)) + cfg = SSHConfig.from_spec(cast(SSHSpecStrict, task.spec), DEFAULT_WORKER_CONFIG) assert cfg.interactive is False assert cfg.command == ["python", "-c", "print(1)"] assert cfg.image == "python:3.12-slim" @@ -61,7 +61,7 @@ def test_interactive_config_default(self) -> None: spec = SSHSpecStrict.model_validate( {"taskType": "ssh", "authorizedKeys": ["ssh-rsa AAAA..."]} ) - cfg = SSHConfig.from_spec(spec) + cfg = SSHConfig.from_spec(spec, DEFAULT_WORKER_CONFIG) assert cfg.interactive is True assert cfg.command is None assert cfg.entrypoint is None @@ -75,7 +75,7 @@ def test_entrypoint_only(self) -> None: "entrypoint": ["/run.sh"], } ) - cfg = SSHConfig.from_spec(spec) + cfg = SSHConfig.from_spec(spec, DEFAULT_WORKER_CONFIG) assert cfg.interactive is False assert cfg.entrypoint == ["/run.sh"] assert cfg.command is None @@ -89,7 +89,7 @@ def test_command_and_entrypoint(self) -> None: "command": ["echo hello"], } ) - cfg = SSHConfig.from_spec(spec) + cfg = SSHConfig.from_spec(spec, DEFAULT_WORKER_CONFIG) assert cfg.interactive is False assert cfg.entrypoint == ["/bin/bash", "-c"] assert cfg.command == ["echo hello"] @@ -114,7 +114,8 @@ def test_command_only(self, tmp_path: Path) -> None: "image": "python:3.12", "command": ["python", "train.py"], } - ) + ), + DEFAULT_WORKER_CONFIG, ) client = MagicMock() result = executor._resolve_noninteractive_command(client, cfg) @@ -129,7 +130,8 @@ def test_entrypoint_only(self, tmp_path: Path) -> None: "image": "myimg", "entrypoint": ["/run.sh"], } - ) + ), + DEFAULT_WORKER_CONFIG, ) client = MagicMock() result = executor._resolve_noninteractive_command(client, cfg) @@ -145,7 +147,8 @@ def test_entrypoint_and_command(self, tmp_path: Path) -> None: "entrypoint": ["/bin/bash", "-c"], "command": ["echo hello"], } - ) + ), + DEFAULT_WORKER_CONFIG, ) client = MagicMock() result = executor._resolve_noninteractive_command(client, cfg) @@ -160,7 +163,8 @@ def test_neither_inspects_image(self, tmp_path: Path) -> None: "interactive": False, "image": "myimg:latest", } - ) + ), + DEFAULT_WORKER_CONFIG, ) mock_image = MagicMock() mock_image.attrs = { @@ -184,7 +188,8 @@ def test_neither_set_no_image_entrypoint_raises(self, tmp_path: Path) -> None: "interactive": False, 
"image": "emptyimg", } - ) + ), + DEFAULT_WORKER_CONFIG, ) mock_image = MagicMock() mock_image.attrs = {"Config": {"Entrypoint": None, "Cmd": None}} @@ -234,25 +239,86 @@ def test_noninteractive_omits_ssh_vars(self, tmp_path: Path) -> None: assert env["MY_VAR"] == "val" assert "FLOWMESH_FINISH_SENTINEL" in env + def test_cuda_visible_devices_normalized_to_slice_count( + self, tmp_path: Path + ) -> None: + executor = self._make_executor(tmp_path) + env = executor._build_environment( + "flowmesh", + [], + {}, + [], + [], + interactive=False, + gpu_device_ids=["2", "3", "5"], + ) + assert env["CUDA_VISIBLE_DEVICES"] == "0,1,2" + + def test_cuda_visible_devices_absent_when_no_gpu_slice( + self, tmp_path: Path + ) -> None: + executor = self._make_executor(tmp_path) + env = executor._build_environment( + "flowmesh", + [], + {}, + [], + [], + interactive=False, + ) + assert "CUDA_VISIBLE_DEVICES" not in env + # ------------------------------------------------------------------ # # _build_run_kwargs tests # ------------------------------------------------------------------ # +def _build_ssh_config( + image: str = "myimg:latest", + *, + cpu_limit: float | None = None, + memory_limit_bytes: int | None = None, + pids_limit: int | None = None, + gpu_device_ids: list[str] | None = None, +) -> SSHConfig: + """Construct a minimal SSHConfig for _build_run_kwargs tests.""" + return SSHConfig( + image=image, + interactive=False, + user="flowmesh", + authorized_keys=[], + command=None, + entrypoint=None, + ttl_sec=60.0, + idle_sec=30.0, + access_mode="direct", + extra_env={}, + inputs=[], + output=None, + mounts=[], + poll_interval_sec=1.0, + stop_timeout_sec=5.0, + cpu_limit=cpu_limit, + memory_limit_bytes=memory_limit_bytes, + pids_limit=pids_limit, + gpu_device_ids=gpu_device_ids if gpu_device_ids is not None else [], + ) + + class TestBuildRunKwargs: def _make_executor( self, tmp_path: Path, docker_gpu_runtime: str | None = None ) -> SSHExecutor: cfg = make_live_worker_config(tmp_path, docker_gpu_runtime=docker_gpu_runtime) - return SSHExecutor(cfg, lifecycle=None) + return SSHExecutor(cfg, hardware=None, lifecycle=None) def test_noninteractive_injects_wrapper_entrypoint_and_command( self, tmp_path: Path ) -> None: executor = self._make_executor(tmp_path) kwargs = executor._build_run_kwargs( - image="myimg:latest", + _build_ssh_config(), container_name="worker-1_ssh-task-1234", environment={}, labels={}, @@ -269,7 +335,7 @@ def test_interactive_does_not_override_image_entrypoint( ) -> None: executor = self._make_executor(tmp_path) kwargs = executor._build_run_kwargs( - image="myimg:latest", + _build_ssh_config(), container_name="worker-1_ssh-task-1234", environment={}, labels={}, @@ -286,7 +352,6 @@ def test_gpu_run_kwargs_omit_runtime_when_not_configured( ) -> None: executor = self._make_executor(tmp_path) fake_device_request = MagicMock(name="device_request") - monkeypatch.setenv("WORKER_HOST_GPU_ID", "0") monkeypatch.setattr( ssh_executor_module, "DeviceRequest", @@ -294,7 +359,7 @@ def test_gpu_run_kwargs_omit_runtime_when_not_configured( ) kwargs = executor._build_run_kwargs( - image="myimg:latest", + _build_ssh_config(gpu_device_ids=["0"]), container_name="worker-1_ssh-task-1234", environment={}, labels={}, @@ -312,7 +377,6 @@ def test_gpu_run_kwargs_include_configured_runtime( ) -> None: executor = self._make_executor(tmp_path, docker_gpu_runtime="nvidia") fake_device_request = MagicMock(name="device_request") - monkeypatch.setenv("WORKER_HOST_GPU_ID", "0") monkeypatch.setattr( 
ssh_executor_module, "DeviceRequest", @@ -320,7 +384,7 @@ def test_gpu_run_kwargs_include_configured_runtime( ) kwargs = executor._build_run_kwargs( - image="myimg:latest", + _build_ssh_config(gpu_device_ids=["0"]), container_name="worker-1_ssh-task-1234", environment={}, labels={}, @@ -333,6 +397,61 @@ def test_gpu_run_kwargs_include_configured_runtime( assert kwargs["device_requests"] == [fake_device_request] assert kwargs["runtime"] == "nvidia" + def test_no_gpu_device_ids_omits_device_requests( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + # Even if WORKER_HOST_GPU_ID is set in env, an empty slice on the + # config (e.g. CPU-only task) must not emit a device_requests kwarg. + monkeypatch.setenv("WORKER_HOST_GPU_ID", "0,1") + executor = self._make_executor(tmp_path) + kwargs = executor._build_run_kwargs( + _build_ssh_config(), + container_name="worker-1_ssh-task-1234", + environment={}, + labels={}, + ports={}, + volumes=[], + command=None, + interactive=False, + ) + assert "device_requests" not in kwargs + + def test_resource_limits_absent_when_unset(self, tmp_path: Path) -> None: + executor = self._make_executor(tmp_path) + kwargs = executor._build_run_kwargs( + _build_ssh_config(), + container_name="worker-1_ssh-task-1234", + environment={}, + labels={}, + ports={}, + volumes=[], + command=None, + interactive=False, + ) + assert "nano_cpus" not in kwargs + assert "mem_limit" not in kwargs + assert "pids_limit" not in kwargs + + def test_resource_limits_applied(self, tmp_path: Path) -> None: + executor = self._make_executor(tmp_path) + kwargs = executor._build_run_kwargs( + _build_ssh_config( + cpu_limit=2.5, + memory_limit_bytes=8 * 1024**3, + pids_limit=256, + ), + container_name="worker-1_ssh-task-1234", + environment={}, + labels={}, + ports={}, + volumes=[], + command=None, + interactive=False, + ) + assert kwargs["nano_cpus"] == int(2.5 * 1_000_000_000) + assert kwargs["mem_limit"] == 8 * 1024**3 + assert kwargs["pids_limit"] == 256 + class TestNoninteractiveContainerStartup: def _make_executor(self, tmp_path: Path) -> SSHExecutor: diff --git a/tests/worker/test_ssh_executor_result_mounting.py b/tests/worker/test_ssh_executor_result_mounting.py index ab84a45c..b7f976e3 100644 --- a/tests/worker/test_ssh_executor_result_mounting.py +++ b/tests/worker/test_ssh_executor_result_mounting.py @@ -9,7 +9,7 @@ from shared.tasks.specs import SSHSpecStrict from shared.tasks.worker_message import WorkerTaskMessage -from tests.worker.factories import make_live_worker_config +from tests.worker.factories import DEFAULT_WORKER_CONFIG, make_live_worker_config from worker.config import WorkerConfig from worker.executors.ssh_executor import ResolvedSSHInput, SSHConfig, SSHExecutor @@ -59,7 +59,7 @@ def test_build_mount_plan_uses_worker_volume_view_in_container( monkeypatch.setenv("RESULTS_DIR", str(results_dir)) task = _task_message() - cfg = SSHConfig.from_spec(cast(SSHSpecStrict, task.spec)) + cfg = SSHConfig.from_spec(cast(SSHSpecStrict, task.spec), DEFAULT_WORKER_CONFIG) executor = SSHExecutor( _worker_config(tmp_path, network_mode="container:flowmesh-worker-1") ) @@ -101,7 +101,7 @@ def test_build_mount_plan_uses_direct_binds_outside_container( monkeypatch.setenv("RESULTS_DIR", str(results_dir)) task = _task_message() - cfg = SSHConfig.from_spec(cast(SSHSpecStrict, task.spec)) + cfg = SSHConfig.from_spec(cast(SSHSpecStrict, task.spec), DEFAULT_WORKER_CONFIG) executor = SSHExecutor(_worker_config(tmp_path, results_mount_source=None)) staged_inputs_dir = tmp_path / 
"staged-inputs" (staged_inputs_dir / "task-pre").mkdir(parents=True) @@ -140,7 +140,7 @@ def test_resolve_inputs_rejects_unsafe_mount_path( monkeypatch.setenv("RESULTS_DIR", str(results_dir)) task = _task_message(inputs=[{"stage": "preprocess", "mountPath": "/tmp/unsafe"}]) - cfg = SSHConfig.from_spec(cast(SSHSpecStrict, task.spec)) + cfg = SSHConfig.from_spec(cast(SSHSpecStrict, task.spec), DEFAULT_WORKER_CONFIG) executor = SSHExecutor( _worker_config(tmp_path, network_mode="container:flowmesh-worker-1") ) @@ -153,7 +153,7 @@ def test_stage_inputs_locally_downloads_missing_upstream_results( tmp_path: Path, monkeypatch: pytest.MonkeyPatch ) -> None: task = _task_message() - cfg = SSHConfig.from_spec(cast(SSHSpecStrict, task.spec)) + cfg = SSHConfig.from_spec(cast(SSHSpecStrict, task.spec), DEFAULT_WORKER_CONFIG) executor = SSHExecutor(_worker_config(tmp_path, results_mount_source=None)) monkeypatch.setenv("RESULTS_DIR", str(tmp_path / "results")) diff --git a/tests/worker/test_ssh_network_isolation.py b/tests/worker/test_ssh_network_isolation.py index dbfb8a21..314738f9 100644 --- a/tests/worker/test_ssh_network_isolation.py +++ b/tests/worker/test_ssh_network_isolation.py @@ -9,11 +9,35 @@ from tests.worker.factories import make_live_worker_config from worker.config import WorkerConfig -from worker.executors.ssh_executor import SSHExecutor +from worker.executors.ssh_executor import SSHConfig, SSHExecutor _SSH_NETWORK_NAME = "flowmesh_ssh_test" +def _ssh_config(image: str = "myimg:latest") -> SSHConfig: + return SSHConfig( + image=image, + interactive=True, + user="flowmesh", + authorized_keys=[], + command=None, + entrypoint=None, + ttl_sec=60.0, + idle_sec=30.0, + access_mode="direct", + extra_env={}, + inputs=[], + output=None, + mounts=[], + poll_interval_sec=1.0, + stop_timeout_sec=5.0, + cpu_limit=None, + memory_limit_bytes=None, + pids_limit=None, + gpu_device_ids=[], + ) + + def _worker_config( tmp_path: Path, ssh_network_name: str | None = _SSH_NETWORK_NAME ) -> WorkerConfig: @@ -203,7 +227,7 @@ def test_includes_network_when_set(self, tmp_path: Path) -> None: executor._ssh_network = _SSH_NETWORK_NAME kwargs = executor._build_run_kwargs( - image="myimg:latest", + _ssh_config(), container_name="worker-1_ssh-task-1234", environment={}, labels={}, @@ -220,7 +244,7 @@ def test_omits_network_when_none(self, tmp_path: Path) -> None: executor._ssh_network = None kwargs = executor._build_run_kwargs( - image="myimg:latest", + _ssh_config(), container_name="worker-1_ssh-task-1234", environment={}, labels={}, @@ -237,7 +261,7 @@ def test_security_opt_always_present(self, tmp_path: Path) -> None: executor._ssh_network = _SSH_NETWORK_NAME kwargs = executor._build_run_kwargs( - image="myimg:latest", + _ssh_config(), container_name="c", environment={}, labels={}, diff --git a/tests/worker/test_ssh_resource_limits.py b/tests/worker/test_ssh_resource_limits.py new file mode 100644 index 00000000..52c117fc --- /dev/null +++ b/tests/worker/test_ssh_resource_limits.py @@ -0,0 +1,394 @@ +"""Tests for SSH container resource-limit resolution and propagation.""" + +import logging +from typing import Any, cast + +import pytest + +from shared.schemas.worker import SSHLimits +from shared.tasks.specs import SSHSpecStrict +from shared.tasks.worker_message import ( + CPUInfo, + GpuInfo, + GpuPlatformInfo, + MemoryInfo, + NetworkInfo, + WorkerHardware, +) +from tests.worker.factories import make_worker_config, make_worker_hardware +from worker.config import WorkerConfig +from worker.executors.ssh_executor 
import SSHConfig + + +def _spec(resources: dict[str, object] | None = None) -> SSHSpecStrict: + payload: dict[str, object] = { + "taskType": "ssh", + "interactive": False, + "image": "python:3.12-slim", + "command": ["true"], + } + if resources is not None: + payload["resources"] = resources + return cast(SSHSpecStrict, SSHSpecStrict.model_validate(payload)) + + +class TestSSHConfigResolveLimits: + def test_no_spec_no_cap_yields_unbounded(self) -> None: + cfg = SSHConfig.from_spec(_spec(), make_worker_config()) + assert cfg.cpu_limit is None + assert cfg.memory_limit_bytes is None + assert cfg.pids_limit is None + + def test_spec_only(self) -> None: + cfg = SSHConfig.from_spec( + _spec({"hardware": {"cpu": 2, "memory": "4Gi"}}), + make_worker_config(), + ) + assert cfg.cpu_limit == 2.0 + assert cfg.memory_limit_bytes == 4 * 1024**3 + assert cfg.pids_limit is None + + def test_worker_cap_only(self) -> None: + cfg = SSHConfig.from_spec( + _spec(), + make_worker_config( + ssh_limits=SSHLimits( + max_cpu_cores=1.0, max_memory_bytes=2 * 1024**3, max_pids=128 + ) + ), + ) + assert cfg.cpu_limit == 1.0 + assert cfg.memory_limit_bytes == 2 * 1024**3 + assert cfg.pids_limit == 128 + + def test_spec_below_cap_uses_spec(self) -> None: + cfg = SSHConfig.from_spec( + _spec({"hardware": {"cpu": 1, "memory": "1Gi"}}), + make_worker_config( + ssh_limits=SSHLimits(max_cpu_cores=4.0, max_memory_bytes=8 * 1024**3) + ), + ) + assert cfg.cpu_limit == 1.0 + assert cfg.memory_limit_bytes == 1 * 1024**3 + + def test_spec_above_cap_clamps_and_warns( + self, caplog: pytest.LogCaptureFixture + ) -> None: + caplog.set_level(logging.WARNING, logger="worker.executors.ssh_executor") + cfg = SSHConfig.from_spec( + _spec({"hardware": {"cpu": 8, "memory": "16Gi"}}), + make_worker_config( + ssh_limits=SSHLimits(max_cpu_cores=2.0, max_memory_bytes=4 * 1024**3) + ), + ) + assert cfg.cpu_limit == 2.0 + assert cfg.memory_limit_bytes == 4 * 1024**3 + messages = " ".join(rec.message for rec in caplog.records) + assert "clamping to cap" in messages + + def test_numeric_memory_is_treated_as_bytes(self) -> None: + cfg = SSHConfig.from_spec( + _spec({"hardware": {"memory": 1048576}}), + make_worker_config(), + ) + assert cfg.memory_limit_bytes == 1048576 + + def test_invalid_memory_string_raises(self) -> None: + with pytest.raises(Exception, match="not a valid memory string"): + SSHConfig.from_spec( + _spec({"hardware": {"memory": "lots"}}), + make_worker_config(), + ) + + +def _worker_config_gpu_limit(**overrides: Any) -> WorkerConfig: + return make_worker_config(enable_ssh_gpu_limit=True, **overrides) + + +class TestSSHConfigResolveGpuDevices: + def test_no_host_gpus_yields_empty_slice( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv("WORKER_HOST_GPU_ID", raising=False) + cfg = SSHConfig.from_spec( + _spec({"hardware": {"gpu": {"count": 1}}}), + make_worker_config(), + ) + assert cfg.gpu_device_ids == [] + + def test_disabled_flag_passes_all_host_gpus_despite_spec( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("WORKER_HOST_GPU_ID", "2,3,4,5") + hardware = make_worker_hardware( + [ + GpuInfo( + index=0, name="T4", uuid="t4-0", memory_total_bytes=16 * 1024**3 + ), + GpuInfo( + index=1, name="A100", uuid="a100-0", memory_total_bytes=80 * 1024**3 + ), + GpuInfo( + index=2, name="A100", uuid="a100-1", memory_total_bytes=80 * 1024**3 + ), + GpuInfo( + index=3, name="A100", uuid="a100-2", memory_total_bytes=80 * 1024**3 + ), + ] + ) + cfg = SSHConfig.from_spec( + _spec( + {"hardware": 
{"gpu": {"count": 1, "type": "A100", "memory": "40Gi"}}} + ), + make_worker_config(), + hardware=hardware, + ) + assert cfg.gpu_device_ids == ["2", "3", "4", "5"] + + def test_disabled_flag_yields_empty_when_no_host_gpus( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv("WORKER_HOST_GPU_ID", raising=False) + cfg = SSHConfig.from_spec(_spec(), make_worker_config()) + assert cfg.gpu_device_ids == [] + + def test_enabled_flag_raises_when_spec_requests_gpus_but_host_has_none( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv("WORKER_HOST_GPU_ID", raising=False) + with pytest.raises(Exception, match="this worker has none"): + SSHConfig.from_spec( + _spec({"hardware": {"gpu": {"count": 1}}}), + _worker_config_gpu_limit(), + ) + + def test_enabled_flag_yields_empty_when_spec_trivial_and_host_empty( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.delenv("WORKER_HOST_GPU_ID", raising=False) + cfg = SSHConfig.from_spec(_spec(), _worker_config_gpu_limit()) + assert cfg.gpu_device_ids == [] + + def test_no_gpu_spec_passes_all_worker_gpus( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("WORKER_HOST_GPU_ID", "2,3") + cfg = SSHConfig.from_spec(_spec(), _worker_config_gpu_limit()) + assert cfg.gpu_device_ids == ["2", "3"] + + def test_count_only_slices_first_n(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("WORKER_HOST_GPU_ID", "2,3,4,5") + cfg = SSHConfig.from_spec( + _spec({"hardware": {"gpu": {"count": 2}}}), + _worker_config_gpu_limit(), + ) + assert cfg.gpu_device_ids == ["2", "3"] + + def test_type_filter_skips_non_matching_devices( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("WORKER_HOST_GPU_ID", "0,1,2") + hardware = make_worker_hardware( + [ + GpuInfo( + index=0, + name="NVIDIA T4", + uuid="t4-0", + memory_total_bytes=16 * 1024**3, + ), + GpuInfo( + index=1, + name="NVIDIA A100-SXM4-80GB", + uuid="a100-0", + memory_total_bytes=80 * 1024**3, + ), + GpuInfo( + index=2, + name="NVIDIA A100-SXM4-80GB", + uuid="a100-1", + memory_total_bytes=80 * 1024**3, + ), + ] + ) + cfg = SSHConfig.from_spec( + _spec({"hardware": {"gpu": {"count": 2, "type": "A100"}}}), + _worker_config_gpu_limit(), + hardware=hardware, + ) + assert cfg.gpu_device_ids == ["1", "2"] + + def test_memory_filter_skips_small_devices( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("WORKER_HOST_GPU_ID", "0,1") + hardware = make_worker_hardware( + [ + GpuInfo( + index=0, name="T4", uuid="t4-0", memory_total_bytes=16 * 1024**3 + ), + GpuInfo( + index=1, name="A100", uuid="a100-0", memory_total_bytes=80 * 1024**3 + ), + ] + ) + cfg = SSHConfig.from_spec( + _spec({"hardware": {"gpu": {"count": 1, "memory": "40Gi"}}}), + _worker_config_gpu_limit(), + hardware=hardware, + ) + assert cfg.gpu_device_ids == ["1"] + + def test_insufficient_matching_devices_raises( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("WORKER_HOST_GPU_ID", "0,1") + hardware = make_worker_hardware( + [ + GpuInfo( + index=0, name="T4", uuid="t4-0", memory_total_bytes=16 * 1024**3 + ), + GpuInfo( + index=1, name="T4", uuid="t4-1", memory_total_bytes=16 * 1024**3 + ), + ] + ) + with pytest.raises(Exception, match="GPU"): + SSHConfig.from_spec( + _spec({"hardware": {"gpu": {"count": 1, "type": "A100"}}}), + _worker_config_gpu_limit(), + hardware=hardware, + ) + + def test_count_zero_yields_empty(self, monkeypatch: pytest.MonkeyPatch) -> None: + 
monkeypatch.setenv("WORKER_HOST_GPU_ID", "2,3") + cfg = SSHConfig.from_spec( + _spec({"hardware": {"gpu": {"count": 0}}}), + _worker_config_gpu_limit(), + ) + assert cfg.gpu_device_ids == [] + + def test_no_hardware_metadata_still_count_slices( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + # Without WorkerHardware, type / memory filters can't be evaluated; + # count-only slicing still works as a graceful fallback. + monkeypatch.setenv("WORKER_HOST_GPU_ID", "2,3,4") + cfg = SSHConfig.from_spec( + _spec({"hardware": {"gpu": {"count": 2}}}), + _worker_config_gpu_limit(), + ) + assert cfg.gpu_device_ids == ["2", "3"] + + def test_unified_memory_satisfies_floor( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + # GB10 / GH200-style unified-memory worker: per-device memory is + # unreported, but the shared pool covers the requested floor. The slice + # should still go through. + monkeypatch.setenv("WORKER_HOST_GPU_ID", "0") + hardware = WorkerHardware( + cpu=CPUInfo(logical_cores=8, model="x"), + memory=MemoryInfo(total_bytes=128 * 1024**3), + gpu=GpuPlatformInfo( + driver_version=None, + cuda_version=None, + devices=[ + GpuInfo( + index=0, + name="NVIDIA GB10", + uuid="gb10", + memory_total_bytes=None, + ) + ], + memory_is_unified=True, + shared_memory_total_bytes=128 * 1024**3, + ), + network=NetworkInfo(ip=None, bandwidth_bytes_per_sec=None), + ) + cfg = SSHConfig.from_spec( + _spec({"hardware": {"gpu": {"count": 1, "memory": "40Gi"}}}), + _worker_config_gpu_limit(), + hardware=hardware, + ) + assert cfg.gpu_device_ids == ["0"] + + def test_type_filter_applies_when_count_omitted( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + # `count` omitted but `type` set: the dispatcher admits the worker on + # one matching device; the slicer must restrict to that single device, + # not pass through all worker GPUs. 
+ monkeypatch.setenv("WORKER_HOST_GPU_ID", "0,1") + hardware = make_worker_hardware( + [ + GpuInfo( + index=0, + name="NVIDIA T4", + uuid="t4-0", + memory_total_bytes=16 * 1024**3, + ), + GpuInfo( + index=1, + name="NVIDIA A100-SXM4-80GB", + uuid="a100-0", + memory_total_bytes=80 * 1024**3, + ), + ] + ) + cfg = SSHConfig.from_spec( + _spec({"hardware": {"gpu": {"type": "A100"}}}), + _worker_config_gpu_limit(), + hardware=hardware, + ) + assert cfg.gpu_device_ids == ["1"] + + def test_memory_filter_applies_when_count_omitted( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("WORKER_HOST_GPU_ID", "0,1") + hardware = make_worker_hardware( + [ + GpuInfo( + index=0, name="T4", uuid="t4-0", memory_total_bytes=16 * 1024**3 + ), + GpuInfo( + index=1, name="A100", uuid="a100-0", memory_total_bytes=80 * 1024**3 + ), + ] + ) + cfg = SSHConfig.from_spec( + _spec({"hardware": {"gpu": {"memory": "40Gi"}}}), + _worker_config_gpu_limit(), + hardware=hardware, + ) + assert cfg.gpu_device_ids == ["1"] + + def test_unified_memory_pool_too_small_raises( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("WORKER_HOST_GPU_ID", "0") + hardware = WorkerHardware( + cpu=CPUInfo(logical_cores=8, model="x"), + memory=MemoryInfo(total_bytes=16 * 1024**3), + gpu=GpuPlatformInfo( + driver_version=None, + cuda_version=None, + devices=[ + GpuInfo( + index=0, + name="NVIDIA GB10", + uuid="gb10", + memory_total_bytes=None, + ) + ], + memory_is_unified=True, + shared_memory_total_bytes=16 * 1024**3, + ), + network=NetworkInfo(ip=None, bandwidth_bytes_per_sec=None), + ) + with pytest.raises(Exception, match="GPU"): + SSHConfig.from_spec( + _spec({"hardware": {"gpu": {"count": 1, "memory": "40Gi"}}}), + _worker_config_gpu_limit(), + hardware=hardware, + )