Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions src/flyte/cli/_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def _run_container(
flyte_demo_config_dir: Path,
volume_name: str,
ports: list[str],
gpu: bool = False,
) -> None:
cmd = [
"docker",
Expand All @@ -106,6 +107,8 @@ def _run_container(
"--volume",
f"{volume_name}:/var/lib/flyte/storage",
]
if gpu:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think we should add some unit tests for these? Also how about an alias devbox

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alias devbox

will change it in a separate PR

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added tests

cmd.extend(["--gpus", "all"])
for port in ports:
cmd.extend(["--publish", port])
cmd.append(image)
Expand Down Expand Up @@ -175,8 +178,10 @@ def _merge_kubeconfig(kubeconfig_path: Path, container_name: str) -> None:

try:
result = _flatten_kubeconfig(default_kubeconfig, kubeconfig_path)
except PermissionError:
# Handle the case that the user does not have permission to kubeconfig file
except (PermissionError, subprocess.CalledProcessError):
# On Linux bind mounts, the in-container kubeconfig lands root-owned on
# the host; kubectl then exits non-zero (CalledProcessError) rather than
# Python raising PermissionError on open.
uid, gid = os.getuid(), os.getgid()
subprocess.run(
["docker", "exec", container_name, "chown", f"{uid}:{gid}", "/.kube/kubeconfig"],
Expand Down Expand Up @@ -222,7 +227,7 @@ def stop_demo() -> None:
console.print("[green]Demo cluster stopped.[/green] Run [bold]flyte start demo[/bold] to resume.")


def launch_demo(image_name: str, is_dev_mode: bool, log_format: str = "console") -> None:
def launch_demo(image_name: str, is_dev_mode: bool, gpu: bool = False, log_format: str = "console") -> None:
_ensure_volume(_VOLUME_NAME)

if _container_is_paused(_CONTAINER_NAME):
Expand All @@ -244,17 +249,17 @@ def launch_demo(image_name: str, is_dev_mode: bool, log_format: str = "console")
steps = _STEPS_DEV if is_dev_mode else _STEPS

if log_format == "json":
_launch_demo_plain(image_name, is_dev_mode, steps)
_launch_demo_plain(image_name, is_dev_mode, steps, gpu=gpu)
else:
_launch_demo_rich(image_name, is_dev_mode, steps)
_launch_demo_rich(image_name, is_dev_mode, steps, gpu=gpu)


def _run_step(step_id: str, image_name: str, is_dev_mode: bool) -> None:
def _run_step(step_id: str, image_name: str, is_dev_mode: bool, gpu: bool = False) -> None:
if step_id == "pull":
_pull_image(image_name)
elif step_id == "start":
_run_container(
image_name, is_dev_mode, _CONTAINER_NAME, _KUBE_DIR, _FLYTE_DEMO_CONFIG_DIR, _VOLUME_NAME, _PORTS
image_name, is_dev_mode, _CONTAINER_NAME, _KUBE_DIR, _FLYTE_DEMO_CONFIG_DIR, _VOLUME_NAME, _PORTS, gpu=gpu
)
elif step_id == "kubeconfig":
_wait_for_kubeconfig(_KUBECONFIG_PATH)
Expand All @@ -266,10 +271,10 @@ def _run_step(step_id: str, image_name: str, is_dev_mode: bool) -> None:
_wait_for_demo_ready(is_dev_mode)


def _launch_demo_plain(image_name: str, is_dev_mode: bool, steps: list[tuple[str, str]]) -> None:
def _launch_demo_plain(image_name: str, is_dev_mode: bool, steps: list[tuple[str, str]], gpu: bool = False) -> None:
for i, (description, step_id) in enumerate(steps, 1):
click.echo(f"[{i}/{len(steps)}] {description}...")
_run_step(step_id, image_name, is_dev_mode)
_run_step(step_id, image_name, is_dev_mode, gpu=gpu)
click.echo(f"[{i}/{len(steps)}] {description}... done")

click.echo("")
Expand All @@ -281,7 +286,7 @@ def _launch_demo_plain(image_name: str, is_dev_mode: bool, steps: list[tuple[str
click.echo(" Image Registry: localhost:30000")


def _launch_demo_rich(image_name: str, is_dev_mode: bool, steps: list[tuple[str, str]]) -> None:
def _launch_demo_rich(image_name: str, is_dev_mode: bool, steps: list[tuple[str, str]], gpu: bool = False) -> None:
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
Expand All @@ -294,7 +299,7 @@ def _launch_demo_rich(image_name: str, is_dev_mode: bool, steps: list[tuple[str,

for description, step_id in steps:
progress.update(overall, description=f"[bold cyan]{description}")
_run_step(step_id, image_name, is_dev_mode)
_run_step(step_id, image_name, is_dev_mode, gpu=gpu)
progress.advance(overall)

if is_dev_mode:
Expand Down
23 changes: 19 additions & 4 deletions src/flyte/cli/_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,15 @@ def tui():
launch_tui_explore()


_DEFAULT_DEMO_IMAGE = "ghcr.io/flyteorg/flyte-demo:nightly"
_DEFAULT_DEMO_GPU_IMAGE = "ghcr.io/flyteorg/flyte-demo:gpu-latest"


@start.command()
@click.option(
"--image",
default="ghcr.io/flyteorg/flyte-demo:nightly",
show_default=True,
default=None,
show_default=f"{_DEFAULT_DEMO_IMAGE} ({_DEFAULT_DEMO_GPU_IMAGE} when --gpu)",
help="Docker image to use for the demo cluster.",
)
@click.option(
Expand All @@ -38,10 +42,21 @@ def tui():
default=False,
help="Enable dev mode inside the demo cluster (sets FLYTE_DEV=True).",
)
@click.option(
"--gpu",
is_flag=True,
default=False,
help="Pass host GPUs into the demo container (adds --gpus all to docker run). "
"Requires an NVIDIA-enabled host. Defaults --image to a GPU-capable image "
"if --image is not explicitly set.",
)
@click.pass_context
def demo(ctx: click.Context, image: str, dev: bool):
def demo(ctx: click.Context, image: str | None, dev: bool, gpu: bool):
"""Start a local Flyte demo cluster."""
from flyte.cli._demo import launch_demo

if image is None:
image = _DEFAULT_DEMO_GPU_IMAGE if gpu else _DEFAULT_DEMO_IMAGE

log_format = getattr(ctx.obj, "log_format", "console") if ctx.obj else "console"
launch_demo(image, dev, log_format=log_format)
launch_demo(image, dev, gpu=gpu, log_format=log_format)
184 changes: 184 additions & 0 deletions tests/cli/test_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""
Unit tests for flyte.cli._demo.

Covers the `--gpu` plumbing on `flyte start demo` and the
kubeconfig chown-retry fallback when kubectl fails to read a root-owned
kubeconfig on Linux bind mounts.
"""

import subprocess
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
from click.testing import CliRunner

from flyte.cli._demo import _merge_kubeconfig, _run_container
from flyte.cli._start import demo


class TestRunContainerGpuFlag:
"""Verify the --gpu flag appends `--gpus all` to the docker run command."""

@staticmethod
def _invoke(gpu: bool) -> list[str]:
with patch("flyte.cli._demo.subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0, stderr="")
_run_container(
image="ghcr.io/flyteorg/flyte-demo:gpu-latest",
is_dev_mode=False,
container_name="flyte-demo",
kube_dir=Path("/tmp/.kube"),
flyte_demo_config_dir=Path("/tmp/.flyte/demo"),
volume_name="flyte-demo",
ports=["30080:30080"],
gpu=gpu,
)
assert mock_run.call_count == 1
return mock_run.call_args.args[0]

def test_gpu_flag_appends_gpus_all(self):
cmd = self._invoke(gpu=True)
assert "--gpus" in cmd
assert cmd[cmd.index("--gpus") + 1] == "all"

def test_gpu_disabled_does_not_set_gpus(self):
cmd = self._invoke(gpu=False)
assert "--gpus" not in cmd

def test_gpu_flag_precedes_image(self):
# `docker run [options] <image>` — --gpus must come before the image arg.
cmd = self._invoke(gpu=True)
assert cmd.index("--gpus") < cmd.index("ghcr.io/flyteorg/flyte-demo:gpu-latest")


class TestMergeKubeconfigRetry:
"""Verify the chown-retry fallback for a root-owned kubeconfig on Linux."""

def test_success_on_first_try_does_not_chown(self, tmp_path):
kubeconfig = tmp_path / "kubeconfig"
kubeconfig.write_text("")

with (
patch("flyte.cli._demo._flatten_kubeconfig") as mock_flatten,
patch("flyte.cli._demo.subprocess.run") as mock_run,
patch("flyte.cli._demo.shutil.move", side_effect=lambda src, dst: Path(dst).touch()),
patch("flyte.cli._demo.Path.home", return_value=tmp_path),
):
mock_flatten.return_value = MagicMock(stdout="apiVersion: v1\n")

_merge_kubeconfig(kubeconfig, "flyte-demo")

assert mock_flatten.call_count == 1
mock_run.assert_not_called()

def test_called_process_error_triggers_chown_and_retry(self, tmp_path):
"""This is the bug fix: on Linux, kubectl exits non-zero (CalledProcessError),
not PermissionError. The retry branch must fire."""
kubeconfig = tmp_path / "kubeconfig"
kubeconfig.write_text("")

with (
patch("flyte.cli._demo._flatten_kubeconfig") as mock_flatten,
patch("flyte.cli._demo.subprocess.run") as mock_run,
patch("flyte.cli._demo.shutil.move", side_effect=lambda src, dst: Path(dst).touch()),
patch("flyte.cli._demo.Path.home", return_value=tmp_path),
):
mock_flatten.side_effect = [
subprocess.CalledProcessError(1, ["kubectl", "config", "view", "--flatten"]),
MagicMock(stdout="apiVersion: v1\n"),
]

_merge_kubeconfig(kubeconfig, "flyte-demo")

assert mock_flatten.call_count == 2
assert mock_run.call_count == 1
docker_cmd = mock_run.call_args.args[0]
assert docker_cmd[:4] == ["docker", "exec", "flyte-demo", "chown"]
assert docker_cmd[-1] == "/.kube/kubeconfig"

def test_permission_error_still_triggers_chown_and_retry(self, tmp_path):
"""Legacy path — macOS users opening the file directly — should still work."""
kubeconfig = tmp_path / "kubeconfig"
kubeconfig.write_text("")

with (
patch("flyte.cli._demo._flatten_kubeconfig") as mock_flatten,
patch("flyte.cli._demo.subprocess.run") as mock_run,
patch("flyte.cli._demo.shutil.move", side_effect=lambda src, dst: Path(dst).touch()),
patch("flyte.cli._demo.Path.home", return_value=tmp_path),
):
mock_flatten.side_effect = [
PermissionError("denied"),
MagicMock(stdout="apiVersion: v1\n"),
]

_merge_kubeconfig(kubeconfig, "flyte-demo")

assert mock_flatten.call_count == 2
assert mock_run.call_count == 1

def test_second_flatten_failure_propagates(self, tmp_path):
"""If kubectl still fails after the chown, we should not swallow the error."""
kubeconfig = tmp_path / "kubeconfig"
kubeconfig.write_text("")

with (
patch("flyte.cli._demo._flatten_kubeconfig") as mock_flatten,
patch("flyte.cli._demo.subprocess.run"),
patch("flyte.cli._demo.Path.home", return_value=tmp_path),
):
err = subprocess.CalledProcessError(1, ["kubectl"])
mock_flatten.side_effect = [err, err]

with pytest.raises(subprocess.CalledProcessError):
_merge_kubeconfig(kubeconfig, "flyte-demo")


class TestDemoCliGpuFlag:
"""Verify the --gpu Click option is plumbed to launch_demo."""

def test_gpu_flag_passed_through(self):
runner = CliRunner()
with patch("flyte.cli._demo.launch_demo") as mock_launch:
result = runner.invoke(demo, ["--gpu", "--image", "flyte-demo:gpu-latest"])
assert result.exit_code == 0, result.output
mock_launch.assert_called_once()
assert mock_launch.call_args.kwargs["gpu"] is True

def test_gpu_defaults_to_false(self):
runner = CliRunner()
with patch("flyte.cli._demo.launch_demo") as mock_launch:
result = runner.invoke(demo, ["--image", "flyte-demo:latest"])
assert result.exit_code == 0, result.output
mock_launch.assert_called_once()
assert mock_launch.call_args.kwargs["gpu"] is False


class TestDemoCliDefaultImage:
"""--gpu without --image should pick the GPU-capable default image."""

def test_gpu_without_image_uses_gpu_default(self):
from flyte.cli._start import _DEFAULT_DEMO_GPU_IMAGE

runner = CliRunner()
with patch("flyte.cli._demo.launch_demo") as mock_launch:
result = runner.invoke(demo, ["--gpu"])
assert result.exit_code == 0, result.output
assert mock_launch.call_args.args[0] == _DEFAULT_DEMO_GPU_IMAGE

def test_no_flags_uses_cpu_default(self):
from flyte.cli._start import _DEFAULT_DEMO_IMAGE

runner = CliRunner()
with patch("flyte.cli._demo.launch_demo") as mock_launch:
result = runner.invoke(demo, [])
assert result.exit_code == 0, result.output
assert mock_launch.call_args.args[0] == _DEFAULT_DEMO_IMAGE

def test_explicit_image_with_gpu_is_respected(self):
runner = CliRunner()
with patch("flyte.cli._demo.launch_demo") as mock_launch:
result = runner.invoke(demo, ["--gpu", "--image", "myorg/custom:latest"])
assert result.exit_code == 0, result.output
assert mock_launch.call_args.args[0] == "myorg/custom:latest"
Loading