7 changes: 7 additions & 0 deletions cli/stack/src/flowmesh_cli_stack/assets/.env.example
@@ -93,6 +93,13 @@ SSH_DEFAULT_IDLE_SEC=
SSH_MAX_TTL_SEC=
SSH_POLL_INTERVAL_SEC=
SSH_STOP_TIMEOUT_SEC=
SSH_MAX_CPU=
SSH_MAX_MEMORY=
SSH_MAX_PIDS=
# Whether to apply requested GPU limits to SSH tasks.
# If false, SSH tasks are allocated all available GPUs
# regardless of their resource requests.
ENABLE_SSH_GPU_LIMIT=false

# ==== General Settings ====
TZ=Asia/Singapore
18 changes: 18 additions & 0 deletions cli/stack/src/flowmesh_cli_stack/env_schema.py
@@ -301,6 +301,24 @@
EnvVar("SSH_MAX_TTL_SEC", var_type=EnvVarType.FLOAT, min_value=0),
EnvVar("SSH_POLL_INTERVAL_SEC", var_type=EnvVarType.FLOAT, min_value=0),
EnvVar("SSH_STOP_TIMEOUT_SEC", var_type=EnvVarType.FLOAT, min_value=0),
EnvVar(
"SSH_MAX_CPU",
var_type=EnvVarType.FLOAT,
min_value=0,
min_inclusive=False,
),
EnvVar("SSH_MAX_MEMORY"),
EnvVar("SSH_MAX_PIDS", var_type=EnvVarType.INT, min_value=1),
EnvVar(
"ENABLE_SSH_GPU_LIMIT",
"false",
var_type=EnvVarType.BOOL,
description=[
"Whether to apply requested GPU limits to SSH tasks.",
"If false, SSH tasks are allocated all available GPUs",
"regardless of their resource requests.",
],
),
],
),
EnvSection(
19 changes: 19 additions & 0 deletions docs/ENV.md
@@ -79,3 +79,22 @@ Spark), set `DOCKER_GPU_RUNTIME=` in the stack env.
| `SUPERVISOR_GRPC_KEEPALIVE_PERMIT_WITHOUT_CALLS` | `true` | gRPC keepalive |
| `SUPERVISOR_GRPC_EXTERNAL_PORT` | – | External port (when port-forwarded) |
| `SERVER_GRPC_TLS_*` | – | TLS certificate files |

## SSH session resource caps

When `enable_ssh` is true on a Docker worker, these ceilings bound every
SSH session container that worker spawns. An unset variable leaves that
dimension unbounded (host-wide access).

| Variable | Default | Description |
|----------|---------|-------------|
| `SSH_MAX_CPU` | – | Max CPU cores per SSH container (float, e.g. `4` or `2.5`). Sets Docker `nano_cpus`. |
| `SSH_MAX_MEMORY` | – | Max memory per SSH container (e.g. `8Gi`, `512Mi`, or a byte count). Sets Docker `mem_limit`. |
| `SSH_MAX_PIDS` | – | Max PIDs per SSH container. Sets Docker `pids_limit`. Admin-only — not user-overridable. |
| `ENABLE_SSH_GPU_LIMIT` | `false` | When `true`, mount only the GPU subset matching the spec (`count` / `type` / `memory`); otherwise mount all worker GPUs. |

The effective CPU/memory limit is `min(spec.resources.hardware, worker cap)`,
as sketched below. A task that requests more than a worker's cap is
dispatched to another worker with a larger (or absent) cap when one exists;
otherwise the dispatcher follows its standard requeue/retry behavior. The
worker logs a startup warning if SSH is enabled with no cap configured.
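
A minimal sketch of that combination rule, assuming float core counts on
both sides; the helper name and signature are illustrative, not the
dispatcher's actual code:

```python
def effective_cpu_limit(
    requested_cores: float | None, cap_cores: float | None
) -> float | None:
    """Illustrative only: combine a task's CPU request with the worker cap."""
    if requested_cores is None:
        return cap_cores        # no request: the cap alone applies
    if cap_cores is None:
        return requested_cores  # no cap: honor the request as-is
    return min(requested_cores, cap_cores)

# A task asking for 8 cores on a worker with SSH_MAX_CPU=4 runs with 4 cores,
# assuming the dispatcher placed it there at all (see the note above).
assert effective_cpu_limit(8.0, 4.0) == 4.0
assert effective_cpu_limit(None, 4.0) == 4.0
```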
2 changes: 2 additions & 0 deletions sdk/src/flowmesh/models/__init__.py
@@ -38,6 +38,7 @@
HostInfo,
MemoryInfo,
NetworkInfo,
SSHLimits,
StorageInfo,
Worker,
WorkerHardware,
@@ -78,6 +79,7 @@
"NodeWorkerInfo",
"ProfileSummary",
"ResultEnvelope",
"SSHLimits",
"StorageInfo",
"TaskInfo",
"TaskStatus",
7 changes: 7 additions & 0 deletions sdk/src/flowmesh/models/workers.py
@@ -73,6 +73,12 @@ class WorkerHardware(BaseModel):
extra: dict[str, Any] | None = None


class SSHLimits(BaseModel):
max_cpu_cores: float | None = None
max_memory_bytes: int | None = None
max_pids: int | None = None


class Worker(BaseModel):
id: str
alias: str | None = None
@@ -85,6 +91,7 @@ class Worker(BaseModel):
pid: int | None = None
env: dict[str, Any] = Field(default_factory=dict)
hardware: WorkerHardware | None = None
ssh_limits: SSHLimits | None = None
tags: list[str] = Field(default_factory=list)
last_seen: str | None = None
cached_models: list[str] = Field(default_factory=list)
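As a usage note on the new model: a short sketch assuming standard Pydantic
v2 semantics (`model_dump_json` / `model_validate_json`), mirroring how the
registry round-trips `ssh_limits_json`:

```python
from pydantic import BaseModel


class SSHLimits(BaseModel):
    max_cpu_cores: float | None = None
    max_memory_bytes: int | None = None
    max_pids: int | None = None


# Every field is optional: an unset field means that dimension is uncapped.
limits = SSHLimits(max_cpu_cores=4.0, max_memory_bytes=8 * 1024**3)
assert limits.max_pids is None

# Round-trips through JSON the same way the registry stores it.
assert SSHLimits.model_validate_json(limits.model_dump_json()) == limits
```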
43 changes: 35 additions & 8 deletions sdk/stack/src/flowmesh_stack/env_schema.py
@@ -1,6 +1,7 @@
"""Environment schema definitions and pure validation helpers."""

import enum
import operator
from collections.abc import Callable, Iterable, Mapping
from dataclasses import dataclass, field
from logging import _nameToLevel as LOG_LEVELS
@@ -34,7 +35,9 @@ class EnvVar:
use_default: bool = False
choices: Iterable[str] | None = None
min_value: float | None = None
min_inclusive: bool = True
max_value: float | None = None
max_inclusive: bool = True
min_length: int | None = None
ensure_path: Literal["error", "warn", "create"] | None = None
url_schemes: set[str] | None = None
@@ -133,19 +136,13 @@ def validate_env_values(
if int_value is None:
errors.append(f"{var.key} must be an integer")
continue
if var.min_value is not None and int_value < var.min_value:
errors.append(f"{var.key} must be >= {int(var.min_value)}")
if var.max_value is not None and int_value > var.max_value:
errors.append(f"{var.key} must be <= {int(var.max_value)}")
_check_value_range(var, int_value, errors)
case EnvVarType.FLOAT:
float_value = parse_float(raw)
if float_value is None:
errors.append(f"{var.key} must be a number")
continue
if var.min_value is not None and float_value < var.min_value:
errors.append(f"{var.key} must be >= {var.min_value}")
if var.max_value is not None and float_value > var.max_value:
errors.append(f"{var.key} must be <= {var.max_value}")
_check_value_range(var, float_value, errors)
case EnvVarType.BOOL:
if parse_bool(raw) is None:
errors.append(
@@ -213,6 +210,36 @@ def require_all_or_none(
errors.append(f"Either all or none of {', '.join(keys)} must be set")


_COMPARE_OPS = {
"<": operator.lt,
"<=": operator.le,
">": operator.gt,
">=": operator.ge,
}


def _check_value_range(var: EnvVar, value: float, errors: list[str]) -> None:
comp: Literal["<", "<=", ">", ">="]
if var.min_value is not None:
comp = ">=" if var.min_inclusive else ">"
_check_bound(var, value, var.min_value, comp, errors)
if var.max_value is not None:
comp = "<=" if var.max_inclusive else "<"
_check_bound(var, value, var.max_value, comp, errors)


def _check_bound(
var: EnvVar,
value: float,
bound: float,
comparison: Literal["<", "<=", ">", ">="],
errors: list[str],
) -> None:
op = _COMPARE_OPS[comparison]
if not op(value, bound):
errors.append(f"{var.key} must be {comparison} {bound}")


def _ensure_path(raw: str, var: EnvVar, errors: list[str], warnings: list[str]) -> None:
if not raw:
errors.append(f"{var.key} must be a non-empty path")
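A hedged sketch of the new exclusive bound in action. The free-standing
helper below mirrors the lower-bound branch of `_check_value_range` for
illustration; it is not part of the module:

```python
import operator

_OPS = {">": operator.gt, ">=": operator.ge}


def check_lower_bound(
    key: str, value: float, min_value: float, min_inclusive: bool
) -> str | None:
    # Mirrors _check_value_range's lower-bound branch (illustrative only).
    comp = ">=" if min_inclusive else ">"
    if not _OPS[comp](value, min_value):
        return f"{key} must be {comp} {min_value}"
    return None


# SSH_MAX_CPU is declared with min_value=0, min_inclusive=False, so a
# configured value of 0 is rejected while any positive float passes.
assert check_lower_bound("SSH_MAX_CPU", 0.0, 0.0, False) == "SSH_MAX_CPU must be > 0.0"
assert check_lower_bound("SSH_MAX_CPU", 0.5, 0.0, False) is None
```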
4 changes: 4 additions & 0 deletions src/server/env.py
@@ -75,6 +75,10 @@
SSH_MAX_TTL_SEC: float | None = parse_float_env("SSH_MAX_TTL_SEC")
SSH_POLL_INTERVAL_SEC: float | None = parse_float_env("SSH_POLL_INTERVAL_SEC")
SSH_STOP_TIMEOUT_SEC: float | None = parse_float_env("SSH_STOP_TIMEOUT_SEC")
SSH_MAX_CPU: float | None = parse_float_env("SSH_MAX_CPU")
SSH_MAX_MEMORY: str | None = os.getenv("SSH_MAX_MEMORY", "").strip() or None
SSH_MAX_PIDS: int | None = parse_int_env("SSH_MAX_PIDS")
ENABLE_SSH_GPU_LIMIT: bool = parse_bool_env("ENABLE_SSH_GPU_LIMIT", False)

LOG_FILE: str = os.getenv("LOG_FILE", "server.log")
LOG_MAX_BYTES: int = int(os.getenv("LOG_MAX_BYTES", 5_242_880))
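`SSH_MAX_MEMORY` stays a raw string here and is converted to bytes later
(the registry imports `parse_mem_to_bytes` from `shared.utils`). A hedged
stand-in for what such a converter typically does; the real helper's
behavior may differ:

```python
_UNITS = {"Ki": 1024, "Mi": 1024**2, "Gi": 1024**3, "Ti": 1024**4}


def parse_mem_to_bytes_sketch(raw: str) -> int | None:
    """Illustrative stand-in for shared.utils.parse_mem_to_bytes."""
    raw = raw.strip()
    if not raw:
        return None
    for suffix, factor in _UNITS.items():
        if raw.endswith(suffix):
            return int(float(raw[: -len(suffix)]) * factor)
    try:
        return int(raw)  # plain byte count, e.g. "8589934592"
    except ValueError:
        return None


assert parse_mem_to_bytes_sketch("8Gi") == 8 * 1024**3
assert parse_mem_to_bytes_sketch("512Mi") == 512 * 1024**2
```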
97 changes: 57 additions & 40 deletions src/server/registries/worker.py
@@ -1,5 +1,4 @@
import json
import re
from collections.abc import Iterable, Sequence
from typing import Any

@@ -10,14 +9,22 @@
StopMessage,
TaskMessage,
)
from shared.schemas.worker import SSHLimits
from shared.tasks import TaskEnvelope
from shared.tasks.components.resources import GPURequirements
from shared.tasks.specs import SSHSpecStrict, SSHSpecTemplate
from shared.tasks.worker_message import (
WorkerHardware,
WorkerStatus,
WorkerTaskMessage,
)
from shared.utils import new_worker_id, now_iso, parse_mem_to_bytes
from shared.utils.hardware import (
normalize_gpu_type,
parse_gpu_memory_bytes,
select_matching_gpu_indices,
unified_gpu_memory_satisfies,
)

from ..clients.redis import (
WORKER_EVENT_CHANNEL,
@@ -48,6 +55,9 @@ class Worker(BaseModel):
hardware: WorkerHardware | None = Field(
default=None, description="Hardware metadata."
)
ssh_limits: SSHLimits | None = Field(
default=None, description="Configured ceiling on SSH session resources."
)
tags: list[str] = Field(default_factory=list, description="Worker tags.")
last_seen: str | None = Field(default=None, description="Last heartbeat timestamp.")
cached_models: list[str] = Field(
@@ -484,14 +494,34 @@ def hw_satisfies(worker: Worker, task: TaskEnvelope) -> bool:
mem_needed = requirements.memory
gpu_req = requirements.gpu

# Consider SSH hardware limits for SSH tasks.
ssh_caps = (
worker.ssh_limits
if isinstance(task.spec, (SSHSpecStrict, SSHSpecTemplate))
and worker.ssh_limits is not None
else None
)

if cpu_needed is not None:
cpu_cores = None if hw is None else hw.cpu.logical_cores
cpu_cores: float | None = None if hw is None else hw.cpu.logical_cores
if (
ssh_caps is not None
and ssh_caps.max_cpu_cores is not None
and cpu_cores is not None
):
cpu_cores = min(cpu_cores, ssh_caps.max_cpu_cores)
if cpu_cores is None or cpu_cores < cpu_needed:
return False

if mem_needed:
required_bytes = parse_mem_to_bytes(str(mem_needed)) or 0
available = 0 if hw is None else (hw.memory.total_bytes or 0)
if (
ssh_caps is not None
and ssh_caps.max_memory_bytes is not None
and available > 0
):
available = min(available, ssh_caps.max_memory_bytes)
if available < required_bytes:
return False

Expand All @@ -511,34 +541,27 @@ def _gpu_meets_requirements(hw: WorkerHardware, gpu_req: GPURequirements) -> boo
required_count = int(required_count)
except Exception:
required_count = None
required_type = str(gpu_req.type or "").strip().lower()
if required_type in {"", "any", "auto", "*"}:
required_type = ""
required_memory = gpu_req.memory
required_memory_bytes = (
parse_mem_to_bytes(str(required_memory)) if required_memory else None
)
required_type = normalize_gpu_type(gpu_req.type)
required_memory_bytes = parse_gpu_memory_bytes(gpu_req.memory)
needed = required_count or 1

entries = hw.gpu.devices
if entries:
if required_count is not None and len(entries) < required_count:
return False
if required_type:
pattern = re.compile(re.escape(required_type), re.IGNORECASE)
if not any(pattern.search(entry.name) for entry in entries):
return False
if required_memory_bytes:
needed = required_count or 1
eligible = sum(
1
for entry in entries
if (entry.memory_total_bytes or 0) >= required_memory_bytes
)
if eligible < needed and not _unified_gpu_memory_satisfies(
hw, required_memory_bytes, needed
):
return False
return True
if len(select_matching_gpu_indices(entries, gpu_req)) >= needed:
return True
# Unified-memory fallback: when memory is the binding constraint and
# the worker exposes a unified GPU/system pool large enough to cover
# the request, still admit it.
if required_memory_bytes is None:
return False
type_only_req = GPURequirements(
count=gpu_req.count, type=gpu_req.type, memory=None
)
if len(select_matching_gpu_indices(entries, type_only_req)) < needed:
return False
return unified_gpu_memory_satisfies(hw, required_memory_bytes, needed)

# Fallback when workers report aggregate GPU data instead of per-device entries.
count = 0 if hw is None else len(hw.gpu.devices)
@@ -554,13 +577,12 @@ def _gpu_meets_requirements(hw: WorkerHardware, gpu_req: GPURequirements) -> bool:

if required_memory_bytes:
total_mem = 0 if first_gpu is None else (first_gpu.memory_total_bytes or 0)
if total_mem <= 0 and _unified_gpu_memory_satisfies(
hw, required_memory_bytes, required_count or 1
if total_mem <= 0 and unified_gpu_memory_satisfies(
hw, required_memory_bytes, needed
):
return True
if total_mem <= 0:
return False
needed = required_count or 1
per_gpu = total_mem / max(needed, 1)
if per_gpu < required_memory_bytes:
return False
Expand All @@ -577,18 +599,6 @@ def dedicated_gpu_memory_total_bytes(hw: WorkerHardware | None) -> int:
return total


def _unified_gpu_memory_satisfies(
hw: WorkerHardware, required_memory_bytes: int, required_count: int
) -> bool:
if not hw.gpu.memory_is_unified:
return False
shared_total = hw.gpu.shared_memory_total_bytes or 0
if shared_total <= 0:
return False
per_gpu_share = shared_total / max(required_count, 1)
return per_gpu_share >= required_memory_bytes


def _parse_worker_from_redis(
worker_id: str, value: dict[str, Any] | None
) -> Worker | None:
@@ -621,6 +631,12 @@ def _ensure_str_list(items: Any) -> list[str]:
if hardware_json is None
else WorkerHardware.model_validate_json(hardware_json)
)
ssh_limits_json = value.get("ssh_limits_json")
ssh_limits = (
None
if ssh_limits_json is None
else SSHLimits.model_validate_json(ssh_limits_json)
)
tags = _loads(value.get("tags_json"), [])
cached_models = _ensure_str_list(_loads(value.get("cache_models_json"), []))
cached_datasets = _ensure_str_list(_loads(value.get("cache_datasets_json"), []))
@@ -650,6 +666,7 @@ def _ensure_str_list(items: Any) -> list[str]:
pid=pid,
env=env,
hardware=hardware,
ssh_limits=ssh_limits,
tags=tags,
last_seen=value.get("last_seen"),
cached_models=cached_models,
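To close, a compact sketch of the cap-aware eligibility check `hw_satisfies`
now performs for SSH tasks: advertised cores are clamped to the SSH cap
before being compared against the request. Names are illustrative, and the
real function also handles memory and GPU matching:

```python
def cpu_eligible(
    worker_cores: float | None, ssh_cap: float | None, requested: float | None
) -> bool:
    """Illustrative: mirrors the CPU branch of hw_satisfies for SSH tasks."""
    if requested is None:
        return True  # task states no CPU requirement
    if worker_cores is None:
        return False  # unknown hardware cannot be proven sufficient
    if ssh_cap is not None:
        worker_cores = min(worker_cores, ssh_cap)
    return worker_cores >= requested


# A 64-core worker capped at SSH_MAX_CPU=4 cannot take an 8-core SSH task;
# the dispatcher keeps looking for a worker with a larger (or absent) cap.
assert cpu_eligible(64, 4, 8) is False
assert cpu_eligible(64, None, 8) is True
```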