diff --git a/.github/workflows/python-tests.yaml b/.github/workflows/python-tests.yaml index ef127eb79..9dc9e8e7b 100644 --- a/.github/workflows/python-tests.yaml +++ b/.github/workflows/python-tests.yaml @@ -72,11 +72,11 @@ jobs: sudo apt-get update sudo apt-get install -y qemu-system-arm qemu-system-x86 - - name: Install libgpiod-dev (Linux) + - name: Install libgpiod-dev and fish (Linux) if: runner.os == 'Linux' run: | sudo apt-get update - sudo apt-get install -y libgpiod-dev liblgpio-dev + sudo apt-get install -y libgpiod-dev liblgpio-dev fish - name: Install Renode (Linux) if: runner.os == 'Linux' @@ -100,6 +100,11 @@ jobs: run: | brew install renode/tap/renode + - name: Install fish (macOS) + if: runner.os == 'macOS' + run: | + brew install fish + - name: Cache Fedora Cloud images id: cache-fedora-cloud-images uses: actions/cache@v5 diff --git a/python/packages/jumpstarter-cli-admin/jumpstarter_cli_admin/__init__.py b/python/packages/jumpstarter-cli-admin/jumpstarter_cli_admin/__init__.py index 41eb3f9f3..89ef858b5 100644 --- a/python/packages/jumpstarter-cli-admin/jumpstarter_cli_admin/__init__.py +++ b/python/packages/jumpstarter-cli-admin/jumpstarter_cli_admin/__init__.py @@ -3,6 +3,7 @@ from jumpstarter_cli_common.opt import opt_log_level from jumpstarter_cli_common.version import version +from .completion import completion from .create import create from .delete import delete from .get import get @@ -16,6 +17,7 @@ def admin(): """Jumpstarter Kubernetes cluster admin CLI tool""" +admin.add_command(completion) admin.add_command(get) admin.add_command(create) admin.add_command(delete) diff --git a/python/packages/jumpstarter-cli-admin/jumpstarter_cli_admin/completion.py b/python/packages/jumpstarter-cli-admin/jumpstarter_cli_admin/completion.py new file mode 100644 index 000000000..c6fa73de8 --- /dev/null +++ b/python/packages/jumpstarter-cli-admin/jumpstarter_cli_admin/completion.py @@ -0,0 +1,10 @@ +from jumpstarter_cli_common.completion import make_completion_command + + +def _get_admin(): + from jumpstarter_cli_admin import admin + + return admin + + +completion = make_completion_command(_get_admin, "jmp-admin", "_JMP_ADMIN_COMPLETE") diff --git a/python/packages/jumpstarter-cli-admin/jumpstarter_cli_admin/completion_test.py b/python/packages/jumpstarter-cli-admin/jumpstarter_cli_admin/completion_test.py new file mode 100644 index 000000000..3a7fd03c9 --- /dev/null +++ b/python/packages/jumpstarter-cli-admin/jumpstarter_cli_admin/completion_test.py @@ -0,0 +1,43 @@ +from click.testing import CliRunner + +from . import admin + + +def test_completion_bash_produces_script_with_jmp_admin(): + runner = CliRunner() + result = runner.invoke(admin, ["completion", "bash"]) + assert result.exit_code == 0 + assert len(result.output) > 0 + assert "complete" in result.output.lower() + assert "jmp-admin" in result.output.lower() + + +def test_completion_zsh_produces_compdef_for_jmp_admin(): + runner = CliRunner() + result = runner.invoke(admin, ["completion", "zsh"]) + assert result.exit_code == 0 + assert len(result.output) > 0 + assert "compdef" in result.output.lower() + + +def test_completion_fish_produces_complete_command_for_jmp_admin(): + runner = CliRunner() + result = runner.invoke(admin, ["completion", "fish"]) + assert result.exit_code == 0 + assert len(result.output) > 0 + assert "complete" in result.output.lower() + assert "--command jmp-admin" in result.output.lower() + + +def test_completion_missing_argument_exits_with_error(): + runner = CliRunner() + result = runner.invoke(admin, ["completion"]) + assert result.exit_code == 2 + assert "Missing argument" in result.output or "bash" in result.output + + +def test_completion_unsupported_shell_exits_with_error(): + runner = CliRunner() + result = runner.invoke(admin, ["completion", "powershell"]) + assert result.exit_code == 2 + assert "Invalid value" in result.output or "powershell" in result.output diff --git a/python/packages/jumpstarter-cli-common/jumpstarter_cli_common/completion.py b/python/packages/jumpstarter-cli-common/jumpstarter_cli_common/completion.py new file mode 100644 index 000000000..0f16f62b5 --- /dev/null +++ b/python/packages/jumpstarter-cli-common/jumpstarter_cli_common/completion.py @@ -0,0 +1,19 @@ +from typing import Callable + +import click +from click.shell_completion import get_completion_class + + +def make_completion_command(cli_group_factory: Callable[[], click.Command], prog_name: str, complete_var: str): + @click.command("completion") + @click.argument("shell", type=click.Choice(["bash", "zsh", "fish"])) + def completion(shell: str): + """Generate shell completion script.""" + cli_group = cli_group_factory() + comp_cls = get_completion_class(shell) + if comp_cls is None: + raise click.ClickException(f"Unsupported shell: {shell}") + comp = comp_cls(cli_group, {}, prog_name, complete_var) + click.echo(comp.source()) + + return completion diff --git a/python/packages/jumpstarter-cli-common/jumpstarter_cli_common/completion_test.py b/python/packages/jumpstarter-cli-common/jumpstarter_cli_common/completion_test.py new file mode 100644 index 000000000..7398cebbe --- /dev/null +++ b/python/packages/jumpstarter-cli-common/jumpstarter_cli_common/completion_test.py @@ -0,0 +1,75 @@ +from unittest.mock import patch + +import click +from click.testing import CliRunner + +from .completion import make_completion_command + +PROG_NAME = "testcli" +COMPLETE_VAR = "_TESTCLI_COMPLETE" + + +def _make_test_group(): + @click.group() + def cli(): + pass + + return cli + + +def _make_test_cli_with_completion(): + @click.group() + def cli(): + pass + + cli.add_command(make_completion_command(_make_test_group, PROG_NAME, COMPLETE_VAR)) + return cli + + +def test_completion_bash_produces_completion_script(): + cli = _make_test_cli_with_completion() + runner = CliRunner() + result = runner.invoke(cli, ["completion", "bash"]) + assert result.exit_code == 0 + assert "complete" in result.output.lower() + assert PROG_NAME in result.output.lower() + + +def test_completion_zsh_produces_compdef(): + cli = _make_test_cli_with_completion() + runner = CliRunner() + result = runner.invoke(cli, ["completion", "zsh"]) + assert result.exit_code == 0 + assert "compdef" in result.output.lower() + + +def test_completion_fish_produces_complete_command(): + cli = _make_test_cli_with_completion() + runner = CliRunner() + result = runner.invoke(cli, ["completion", "fish"]) + assert result.exit_code == 0 + assert "complete" in result.output.lower() + assert f"--command {PROG_NAME}" in result.output.lower() + + +def test_completion_missing_argument_exits_with_error(): + cli = _make_test_cli_with_completion() + runner = CliRunner() + result = runner.invoke(cli, ["completion"]) + assert result.exit_code == 2 + + +def test_completion_unsupported_shell_exits_with_error(): + cli = _make_test_cli_with_completion() + runner = CliRunner() + result = runner.invoke(cli, ["completion", "powershell"]) + assert result.exit_code == 2 + + +def test_completion_raises_when_get_completion_class_returns_none(): + with patch("jumpstarter_cli_common.completion.get_completion_class", return_value=None): + cli = _make_test_cli_with_completion() + runner = CliRunner() + result = runner.invoke(cli, ["completion", "bash"]) + assert result.exit_code == 1 + assert "Unsupported shell" in result.output diff --git a/python/packages/jumpstarter-cli/jumpstarter_cli/completion.py b/python/packages/jumpstarter-cli/jumpstarter_cli/completion.py index f21def97f..eab81ed4c 100644 --- a/python/packages/jumpstarter-cli/jumpstarter_cli/completion.py +++ b/python/packages/jumpstarter-cli/jumpstarter_cli/completion.py @@ -1,15 +1,10 @@ -import click -from click.shell_completion import get_completion_class +from jumpstarter_cli_common.completion import make_completion_command -@click.command("completion") -@click.argument("shell", type=click.Choice(["bash", "zsh", "fish"])) -def completion(shell: str): - """Generate shell completion script.""" +def _get_jmp(): from jumpstarter_cli.jmp import jmp - comp_cls = get_completion_class(shell) - if comp_cls is None: - raise click.ClickException(f"Unsupported shell: {shell}") - comp = comp_cls(jmp, {}, "jmp", "_JMP_COMPLETE") - click.echo(comp.source()) + return jmp + + +completion = make_completion_command(_get_jmp, "jmp", "_JMP_COMPLETE") diff --git a/python/packages/jumpstarter-cli/jumpstarter_cli/completion_test.py b/python/packages/jumpstarter-cli/jumpstarter_cli/completion_test.py index f8d74ae05..100517be4 100644 --- a/python/packages/jumpstarter-cli/jumpstarter_cli/completion_test.py +++ b/python/packages/jumpstarter-cli/jumpstarter_cli/completion_test.py @@ -1,5 +1,3 @@ -from unittest.mock import patch - from click.testing import CliRunner from .jmp import jmp @@ -43,11 +41,3 @@ def test_completion_unsupported_shell(): result = runner.invoke(jmp, ["completion", "powershell"]) assert result.exit_code == 2 assert "Invalid value" in result.output or "powershell" in result.output - - -def test_completion_raises_when_get_completion_class_returns_none(): - with patch("jumpstarter_cli.completion.get_completion_class", return_value=None): - runner = CliRunner() - result = runner.invoke(jmp, ["completion", "bash"]) - assert result.exit_code == 1 - assert "Unsupported shell" in result.output diff --git a/python/packages/jumpstarter-cli/jumpstarter_cli/j.py b/python/packages/jumpstarter-cli/jumpstarter_cli/j.py index c077fce51..e9fc93d57 100644 --- a/python/packages/jumpstarter-cli/jumpstarter_cli/j.py +++ b/python/packages/jumpstarter-cli/jumpstarter_cli/j.py @@ -1,12 +1,15 @@ import concurrent.futures._base +import os import sys from contextlib import ExitStack from typing import cast +import anyio import click from anyio import create_task_group, get_cancelled_exc_class, run, to_thread from anyio.from_thread import BlockingPortal from click.exceptions import Exit as ClickExit +from jumpstarter_cli_common.completion import make_completion_command from jumpstarter_cli_common.exceptions import ( ClickExceptionRed, async_handle_exceptions, @@ -19,6 +22,29 @@ from jumpstarter.common.exceptions import EnvironmentVariableNotSetError from jumpstarter.utils.env import env_async +j_completion = make_completion_command(lambda: click.Group("j"), "j", "_J_COMPLETE") + + +_COMPLETION_TIMEOUT_SECONDS = 5 + + +async def _j_shell_complete(): + try: + with anyio.fail_after(_COMPLETION_TIMEOUT_SECONDS): + async with BlockingPortal() as portal: + with ExitStack() as stack: + async with env_async(portal, stack) as client: + + def _run_completion(): + try: + client.cli()(standalone_mode=False) + except SystemExit: + pass + + await to_thread.run_sync(_run_completion, abandon_on_cancel=True) + except TimeoutError: + pass + async def j_async(): @async_handle_exceptions @@ -60,6 +86,12 @@ async def cli(): def j(): traceback.install() + if len(sys.argv) >= 2 and sys.argv[1] == "completion": + j_completion(args=sys.argv[2:]) + return + if "_J_COMPLETE" in os.environ: + run(_j_shell_complete) + return run(j_async) diff --git a/python/packages/jumpstarter-cli/jumpstarter_cli/j_completion_test.py b/python/packages/jumpstarter-cli/jumpstarter_cli/j_completion_test.py new file mode 100644 index 000000000..23b7030c6 --- /dev/null +++ b/python/packages/jumpstarter-cli/jumpstarter_cli/j_completion_test.py @@ -0,0 +1,73 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import anyio +from anyio import run +from click.testing import CliRunner + +from .j import _COMPLETION_TIMEOUT_SECONDS, _j_shell_complete, j_completion + + +def test_j_completion_bash_produces_script(): + runner = CliRunner() + result = runner.invoke(j_completion, ["bash"]) + assert result.exit_code == 0 + assert "complete" in result.output.lower() + assert "_J_COMPLETE" in result.output + + +def test_j_completion_zsh_produces_compdef(): + runner = CliRunner() + result = runner.invoke(j_completion, ["zsh"]) + assert result.exit_code == 0 + assert "compdef" in result.output.lower() + + +def test_j_completion_fish_produces_complete_command(): + runner = CliRunner() + result = runner.invoke(j_completion, ["fish"]) + assert result.exit_code == 0 + assert "complete" in result.output.lower() + assert "--command j" in result.output.lower() + + +def test_j_completion_no_args_exits_with_error(): + runner = CliRunner() + result = runner.invoke(j_completion, []) + assert result.exit_code == 2 + + +def test_j_completion_unsupported_shell_exits_with_error(): + runner = CliRunner() + result = runner.invoke(j_completion, ["powershell"]) + assert result.exit_code == 2 + + +def test_j_shell_complete_handles_system_exit_cleanly(): + mock_cli_group = MagicMock() + mock_cli_group.side_effect = SystemExit(0) + mock_client = MagicMock() + mock_client.cli.return_value = mock_cli_group + + with patch("jumpstarter_cli.j.env_async") as mock_env: + mock_env.return_value.__aenter__ = AsyncMock(return_value=mock_client) + mock_env.return_value.__aexit__ = AsyncMock(return_value=False) + run(_j_shell_complete) + mock_client.cli.assert_called_once() + mock_cli_group.assert_called_once() + + +def test_j_shell_complete_returns_empty_on_timeout(): + from contextlib import asynccontextmanager + + @asynccontextmanager + async def slow_env(*args, **kwargs): + await anyio.sleep(_COMPLETION_TIMEOUT_SECONDS + 1) + yield MagicMock() + + with patch("jumpstarter_cli.j.env_async", slow_env): + result = run(_j_shell_complete) + assert result is None + + +def test_completion_timeout_is_positive(): + assert _COMPLETION_TIMEOUT_SECONDS > 0 diff --git a/python/packages/jumpstarter-cli/jumpstarter_cli/shell.py b/python/packages/jumpstarter-cli/jumpstarter_cli/shell.py index 47e7952b5..8ded3d837 100644 --- a/python/packages/jumpstarter-cli/jumpstarter_cli/shell.py +++ b/python/packages/jumpstarter-cli/jumpstarter_cli/shell.py @@ -42,7 +42,7 @@ -def _run_shell_only(lease, config, command, path: str) -> int: +def _run_shell_only(lease, config, command, path: str, j_commands: list[str] | None = None) -> int: """Run just the shell command without log streaming.""" allow = config.drivers.allow if config is not None else getattr(lease, "allow", []) unsafe = config.drivers.unsafe if config is not None else getattr(lease, "unsafe", False) @@ -59,6 +59,7 @@ def _run_shell_only(lease, config, command, path: str) -> int: lease=lease, insecure=insecure, passphrase=passphrase, + j_commands=j_commands, ) @@ -324,8 +325,18 @@ async def _run_shell_with_lease_async(lease, exporter_logs, config, command, can warning_text = monitor.status_message[len(HOOK_WARNING_PREFIX) :] click.echo(click.style(f"Warning: {warning_text}", fg="yellow", bold=True)) - # Run the shell command - exit_code = await anyio.to_thread.run_sync(_run_shell_only, lease, config, command, path) + # Extract j command names for static shell completion + j_commands = None + try: + cli_group = client.cli() + if hasattr(cli_group, "list_commands"): + j_commands = cli_group.list_commands(None) + except Exception as e: + logger.debug("Failed to extract j commands for completion: %s", e) + + exit_code = await anyio.to_thread.run_sync( + _run_shell_only, lease, config, command, path, j_commands + ) # Shell has exited. For auto-created leases (release=True), call # EndSession to trigger afterLease hook while keeping log stream diff --git a/python/packages/jumpstarter/jumpstarter/common/utils.py b/python/packages/jumpstarter/jumpstarter/common/utils.py index 7a3e13921..052e65a4b 100644 --- a/python/packages/jumpstarter/jumpstarter/common/utils.py +++ b/python/packages/jumpstarter/jumpstarter/common/utils.py @@ -1,6 +1,10 @@ import os +import re +import shlex +import shutil import signal import sys +import tempfile from contextlib import ExitStack, asynccontextmanager, contextmanager from datetime import timedelta from functools import partial @@ -84,6 +88,178 @@ def _run_process( return process.wait() +_SAFE_COMMAND_NAME = re.compile(r"^[a-zA-Z0-9_-]+$") + + +def _validate_j_commands(j_commands: list[str] | None) -> list[str] | None: + """Filter j_commands to only include safe alphanumeric names.""" + if j_commands is None: + return None + return [cmd for cmd in j_commands if _SAFE_COMMAND_NAME.match(cmd)] + + +def _resolve_cli_paths() -> tuple[str, str, str]: + """Resolve absolute paths for jmp, jmp-admin, and j CLI tools.""" + jmp = shutil.which("jmp") or "jmp" + jmp_admin = shutil.which("jmp-admin") or "jmp-admin" + j = shutil.which("j") or "j" + return jmp, jmp_admin, j + + +def _generate_shell_init(shell_name: str, use_profiles: bool, j_commands: list[str] | None = None) -> str: + """Generate shell-specific init script content for completion and profile sourcing.""" + j_commands = _validate_j_commands(j_commands) + jmp, jmp_admin, j = _resolve_cli_paths() + if shell_name.endswith("bash"): + lines = [] + if use_profiles: + lines.append('[ -f ~/.bashrc ] && source ~/.bashrc') + lines.append(f'eval "$({jmp} completion bash 2>/dev/null)"') + lines.append(f'eval "$({jmp_admin} completion bash 2>/dev/null)"') + if j_commands: + cmds = " ".join(j_commands) + completion_fn = ( + f'_j_completion() {{ [[ ${{COMP_CWORD}} -eq 1 ]]' + f' && COMPREPLY=($(compgen -W "{cmds}" -- "${{COMP_WORDS[COMP_CWORD]}}")); }}' + ) + lines.append(completion_fn) + lines.append("complete -o default -F _j_completion j") + else: + lines.append(f'eval "$({j} completion bash 2>/dev/null)"') + return "\n".join(lines) + "\n" + + elif shell_name.endswith("zsh"): + lines = [] + if use_profiles: + lines.append('[ -f ~/.zshrc ] && source ~/.zshrc') + lines.append("autoload -Uz compinit && compinit") + lines.append(f'eval "$({jmp} completion zsh 2>/dev/null)"') + lines.append(f'eval "$({jmp_admin} completion zsh 2>/dev/null)"') + if j_commands: + cmds = " ".join(j_commands) + lines.append(f"compdef '_arguments \"1:subcommand:({cmds})\"' j") + else: + lines.append(f'eval "$({j} completion zsh 2>/dev/null)"') + return "\n".join(lines) + "\n" + + elif shell_name.endswith("fish"): + lines = [] + lines.append(f"{jmp} completion fish 2>/dev/null | source") + lines.append(f"{jmp_admin} completion fish 2>/dev/null | source") + if j_commands: + for cmd in j_commands: + lines.append(f"complete -c j -f -n '__fish_use_subcommand' -a '{cmd}'") + else: + lines.append(f"{j} completion fish 2>/dev/null | source") + return "\n".join(lines) + "\n" + + return "" + + +def _launch_bash(shell, init_content, use_profiles, common_env, context, lease): + """Launch a bash shell with completion init and custom prompt.""" + env = common_env | { + "_JMP_SHELL_CONTEXT": context, + "PS1": f"{ANSI_GRAY}{PROMPT_CWD} {ANSI_YELLOW}⚡{ANSI_WHITE}{context} {ANSI_YELLOW}➤{ANSI_RESET} ", + } + cmd = [shell] + init_file = None + if init_content: + init_content += ( + f'PS1="{ANSI_GRAY}{PROMPT_CWD} {ANSI_YELLOW}⚡{ANSI_WHITE}' + '$_JMP_SHELL_CONTEXT' + f' {ANSI_YELLOW}➤{ANSI_RESET} "\n' + ) + init_file = tempfile.NamedTemporaryFile(mode="w", suffix=".sh", delete=False) + init_file.write(init_content) + init_file.close() + cmd.extend(["--rcfile", init_file.name]) + elif not use_profiles: + cmd.extend(["--norc", "--noprofile"]) + try: + return _run_process(cmd, env, lease) + finally: + if init_file: + try: + os.unlink(init_file.name) + except OSError: + pass + + +def _launch_fish(shell, init_content, common_env, context, lease): + """Launch a fish shell with completion init and custom prompt.""" + fish_env = common_env | {"_JMP_SHELL_CONTEXT": context} + fish_fn = ( + "function fish_prompt; " + "set_color grey; " + 'printf "%s" (basename $PWD); ' + "set_color yellow; " + 'printf "⚡"; ' + "set_color white; " + 'printf "%s" "$_JMP_SHELL_CONTEXT"; ' + "set_color yellow; " + 'printf "➤ "; ' + "set_color normal; " + "end" + ) + init_cmd = fish_fn + init_file = None + if init_content: + init_file = tempfile.NamedTemporaryFile(mode="w", suffix=".sh", delete=False) + init_file.write(init_content) + init_file.close() + fish_env["_JMP_SHELL_INIT"] = init_file.name + init_cmd += '; source "$_JMP_SHELL_INIT"' + try: + return _run_process([shell, "--init-command", init_cmd], fish_env, lease) + finally: + if init_file: + try: + os.unlink(init_file.name) + except OSError: + pass + + +def _launch_zsh(shell, init_content, common_env, context, lease, use_profiles): + """Launch a zsh shell with completion init, custom prompt, and ZDOTDIR management.""" + env = common_env | { + "_JMP_SHELL_CONTEXT": context, + "PS1": f"%F{{8}}%1~ %F{{yellow}}⚡%F{{white}}{context} %F{{yellow}}➤%f ", + } + if "HISTFILE" not in env: + env["HISTFILE"] = os.path.join(os.path.expanduser("~"), ".zsh_history") + cmd = [shell] + tmpdir = None + if init_content: + init_content += ( + 'PROMPT="%F{8}%1~ %F{yellow}⚡%F{white}' + '${_JMP_SHELL_CONTEXT} %F{yellow}➤%f "\n' + ) + tmpdir = tempfile.mkdtemp() + original_zdotdir = env.get("ZDOTDIR", os.path.expanduser("~")) + original_zshenv = os.path.join(original_zdotdir, ".zshenv") + zshenv_path = os.path.join(tmpdir, ".zshenv") + with open(zshenv_path, "w") as f: + f.write(f"[ -f {shlex.quote(original_zshenv)} ] && source {shlex.quote(original_zshenv)}\n") + zshrc_path = os.path.join(tmpdir, ".zshrc") + with open(zshrc_path, "w") as f: + f.write(f"ZDOTDIR={shlex.quote(original_zdotdir)}\n") + f.write(init_content) + cmd.extend(["--rcs", "-o", "inc_append_history", "-o", "share_history"]) + env["ZDOTDIR"] = tmpdir + else: + if not use_profiles: + cmd.append("--no-rcs") + cmd.extend(["-o", "inc_append_history", "-o", "share_history"]) + try: + return _run_process(cmd, env, lease) + finally: + if tmpdir: + import shutil + + shutil.rmtree(tmpdir, ignore_errors=True) + + def launch_shell( host: str, context: str, @@ -95,29 +271,14 @@ def launch_shell( lease=None, insecure: bool = False, passphrase: str | None = None, + j_commands: list[str] | None = None, ) -> int: - """Launch a shell with a custom prompt indicating the exporter type. - - Args: - host: The jumpstarter host path - context: The context of the shell (e.g. "local" or exporter name) - allow: List of allowed drivers - unsafe: Whether to allow drivers outside of the allow list - use_profiles: Whether to load shell profile files - command: Optional command to run instead of launching an interactive shell - lease: Optional Lease object to set up lease ending callback - - Returns: - The exit code of the shell or command process - """ - shell = os.environ.get("SHELL", "bash") shell_name = os.path.basename(shell) - common_env = os.environ | { JUMPSTARTER_HOST: host, JMP_DRIVERS_ALLOW: "UNSAFE" if unsafe else ",".join(allow), - "_JMP_SUPPRESS_DRIVER_WARNINGS": "1", # Already warned during client initialization + "_JMP_SUPPRESS_DRIVER_WARNINGS": "1", } if insecure: common_env = common_env | {JMP_GRPC_INSECURE: "1"} @@ -127,44 +288,15 @@ def launch_shell( if command: return _run_process(list(command), common_env, lease) - if shell_name.endswith("bash"): - env = common_env | { - "PS1": f"{ANSI_GRAY}{PROMPT_CWD} {ANSI_YELLOW}⚡{ANSI_WHITE}{context} {ANSI_YELLOW}➤{ANSI_RESET} ", - } - cmd = [shell] - if not use_profiles: - cmd.extend(["--norc", "--noprofile"]) - return _run_process(cmd, env, lease) + init_content = _generate_shell_init(shell_name, use_profiles, j_commands) - elif shell_name == "fish": - fish_fn = ( - "function fish_prompt; " - "set_color grey; " - 'printf "%s" (basename $PWD); ' - "set_color yellow; " - 'printf "⚡"; ' - "set_color white; " - f'printf "{context}"; ' - "set_color yellow; " - 'printf "➤ "; ' - "set_color normal; " - "end" - ) - cmd = [shell, "--init-command", fish_fn] - return _run_process(cmd, common_env, lease) + if shell_name.endswith("zsh"): + return _launch_zsh(shell, init_content, common_env, context, lease, use_profiles) - elif shell_name == "zsh": - env = common_env | { - "PS1": f"%F{{8}}%1~ %F{{yellow}}⚡%F{{white}}{context} %F{{yellow}}➤%f ", - } - if "HISTFILE" not in env: - env["HISTFILE"] = os.path.join(os.path.expanduser("~"), ".zsh_history") + if shell_name.endswith("bash"): + return _launch_bash(shell, init_content, use_profiles, common_env, context, lease) - cmd = [shell] - if not use_profiles: - cmd.append("--no-rcs") - cmd.extend(["-o", "inc_append_history", "-o", "share_history"]) - return _run_process(cmd, env, lease) + if shell_name.endswith("fish"): + return _launch_fish(shell, init_content, common_env, context, lease) - else: - return _run_process([shell], common_env, lease) + return _run_process([shell], common_env, lease) diff --git a/python/packages/jumpstarter/jumpstarter/common/utils_test.py b/python/packages/jumpstarter/jumpstarter/common/utils_test.py index 86cd52dfd..cb137ce09 100644 --- a/python/packages/jumpstarter/jumpstarter/common/utils_test.py +++ b/python/packages/jumpstarter/jumpstarter/common/utils_test.py @@ -1,6 +1,21 @@ +import os import shutil +import subprocess +import tempfile +from unittest.mock import patch -from .utils import launch_shell +import pytest + +from .utils import ( + ANSI_GRAY, + ANSI_RESET, + ANSI_WHITE, + ANSI_YELLOW, + PROMPT_CWD, + _generate_shell_init, + _validate_j_commands, + launch_shell, +) def test_launch_shell(tmp_path, monkeypatch): @@ -22,3 +37,581 @@ def test_launch_shell(tmp_path, monkeypatch): use_profiles=False ) assert exit_code == 1 + + +def test_generate_shell_init_uses_absolute_paths_for_completion(monkeypatch): + def fake_which(name): + return f"/usr/bin/{name}" + + monkeypatch.setattr(shutil, "which", fake_which) + + content = _generate_shell_init("zsh", use_profiles=True, j_commands=None) + for line in content.splitlines(): + if "completion zsh" in line and "eval" in line: + dollar_paren = line.split("$(")[1].split(")")[0] + cmd = dollar_paren.split()[0] + assert cmd.startswith("/"), f"Expected absolute path for command, got: {cmd}" + + +def test_generate_bash_init_with_j_commands(): + content = _generate_shell_init("bash", use_profiles=False, j_commands=["power", "serial", "ssh"]) + assert "_j_completion" in content + assert "power serial ssh" in content + assert "jmp completion bash" in content + assert "jmp-admin completion bash" in content + assert ".bashrc" not in content + + +def test_generate_bash_init_with_profiles(): + content = _generate_shell_init("bash", use_profiles=True, j_commands=["power"]) + assert ".bashrc" in content + assert "_j_completion" in content + + +def test_generate_bash_init_without_j_commands(): + content = _generate_shell_init("bash", use_profiles=False, j_commands=None) + assert "j completion bash" in content + assert "_j_completion" not in content + + +def test_generate_zsh_init_with_j_commands(): + content = _generate_shell_init("zsh", use_profiles=False, j_commands=["power", "qemu"]) + assert "jmp completion zsh" in content + assert "compdef" in content + assert "1:subcommand:(power qemu)" in content + + +def test_generate_zsh_init_loads_compinit_before_completions(): + content = _generate_shell_init("zsh", use_profiles=False, j_commands=["power"]) + assert "autoload -Uz compinit && compinit" in content + compinit_pos = content.index("autoload -Uz compinit && compinit") + eval_jmp_pos = content.index("completion zsh") + assert compinit_pos < eval_jmp_pos + + +def test_generate_zsh_init_loads_compinit_before_compdef(): + content = _generate_shell_init("zsh", use_profiles=False, j_commands=["power", "qemu"]) + compinit_pos = content.index("autoload -Uz compinit && compinit") + compdef_pos = content.index("compdef") + assert compinit_pos < compdef_pos + + +def test_generate_zsh_init_without_j_commands_loads_compinit(): + content = _generate_shell_init("zsh", use_profiles=False, j_commands=None) + assert "autoload -Uz compinit && compinit" in content + compinit_pos = content.index("autoload -Uz compinit && compinit") + eval_jmp_pos = content.index("completion zsh") + assert compinit_pos < eval_jmp_pos + + +def test_generate_bash_init_with_profiles_sources_bashrc(): + content = _generate_shell_init("bash", use_profiles=True, j_commands=None) + assert ".bashrc" in content + assert "j completion bash" in content + + +def test_generate_zsh_init_without_j_commands(): + content = _generate_shell_init("zsh", use_profiles=False, j_commands=None) + assert "j completion zsh" in content + assert "compdef" not in content + + +def test_generate_zsh_init_with_profiles_loads_zshrc_before_compinit(): + content = _generate_shell_init("zsh", use_profiles=True, j_commands=["power"]) + assert ".zshrc" in content + compinit_pos = content.index("autoload -Uz compinit && compinit") + zshrc_pos = content.index(".zshrc") + assert zshrc_pos < compinit_pos + + +def test_generate_fish_init_with_j_commands(): + content = _generate_shell_init("fish", use_profiles=False, j_commands=["power", "qemu"]) + assert "'power'" in content + assert "'qemu'" in content + assert "jmp completion fish" in content + + +def test_generate_fish_init_without_j_commands(): + content = _generate_shell_init("fish", use_profiles=False, j_commands=None) + assert "j completion fish" in content + + +def test_generate_shell_init_unknown_shell(): + content = _generate_shell_init("csh", use_profiles=False, j_commands=["power"]) + assert content == "" + + +def test_launch_shell_with_j_commands(tmp_path, monkeypatch): + monkeypatch.setenv("SHELL", shutil.which("true")) + exit_code = launch_shell( + host=str(tmp_path / "test.sock"), + context="remote", + allow=["*"], + unsafe=False, + use_profiles=False, + j_commands=["power", "serial"], + ) + assert exit_code == 0 + + +def test_validate_j_commands_filters_unsafe_names(): + assert _validate_j_commands(None) is None + assert _validate_j_commands(["power", "serial"]) == ["power", "serial"] + assert _validate_j_commands(["good-cmd", "good_cmd"]) == ["good-cmd", "good_cmd"] + assert _validate_j_commands(["$(evil)", "power"]) == ["power"] + assert _validate_j_commands(["bad;cmd", "ok"]) == ["ok"] + assert _validate_j_commands(["bad cmd", "ok"]) == ["ok"] + assert _validate_j_commands(['"injection', "ok"]) == ["ok"] + + +def test_generate_shell_init_excludes_unsafe_j_commands(): + content = _generate_shell_init("bash", use_profiles=False, j_commands=["power", "$(evil)", "serial"]) + assert "power" in content + assert "serial" in content + assert "$(evil)" not in content + + +def test_launch_shell_zsh_cleans_up_all_temp_files(tmp_path, monkeypatch): + monkeypatch.setenv("SHELL", "/usr/bin/zsh") + zshrc_paths = [] + + def mock_run_process(cmd, env, lease=None): + zdotdir = env.get("ZDOTDIR") + if zdotdir: + zshrc = os.path.join(zdotdir, ".zshrc") + zshrc_paths.append(zshrc) + assert os.path.exists(zshrc) + return 0 + + with patch("jumpstarter.common.utils._run_process", mock_run_process): + exit_code = launch_shell( + host=str(tmp_path / "test.sock"), + context="remote", + allow=["*"], + unsafe=False, + use_profiles=False, + j_commands=["power", "serial"], + ) + assert exit_code == 0 + + assert len(zshrc_paths) == 1 + assert not os.path.exists(zshrc_paths[0]) + + +def test_launch_fish_cleans_up_temp_init_file(tmp_path, monkeypatch): + monkeypatch.setenv("SHELL", "/usr/bin/fish") + init_file_paths = [] + + def mock_run_process(cmd, env, lease=None): + init_path = env.get("_JMP_SHELL_INIT") + if init_path: + init_file_paths.append(init_path) + assert os.path.exists(init_path), "init file must exist during process run" + return 0 + + with patch("jumpstarter.common.utils._run_process", mock_run_process): + exit_code = launch_shell( + host=str(tmp_path / "test.sock"), + context="remote", + allow=["*"], + unsafe=False, + use_profiles=False, + j_commands=["power", "serial"], + ) + + assert exit_code == 0 + assert len(init_file_paths) == 1 + assert not os.path.exists(init_file_paths[0]), "init file must be cleaned up after process exits" + + +def test_launch_fish_passes_context_via_env(tmp_path, monkeypatch): + monkeypatch.setenv("SHELL", "/usr/bin/fish") + captured_env = {} + captured_cmd = [] + + def mock_run_process(cmd, env, lease=None): + captured_env.update(env) + captured_cmd.extend(cmd) + return 0 + + context = "test-context" + with patch("jumpstarter.common.utils._run_process", mock_run_process): + launch_shell( + host=str(tmp_path / "test.sock"), + context=context, + allow=["*"], + unsafe=False, + use_profiles=False, + ) + + assert captured_env.get("_JMP_SHELL_CONTEXT") == context + init_cmd_arg = captured_cmd[captured_cmd.index("--init-command") + 1] + assert context not in init_cmd_arg + + +def test_launch_fish_passes_init_file_via_env(tmp_path, monkeypatch): + monkeypatch.setenv("SHELL", "/usr/bin/fish") + captured_env = {} + captured_cmd = [] + + def mock_run_process(cmd, env, lease=None): + captured_env.update(env) + captured_cmd.extend(cmd) + return 0 + + with patch("jumpstarter.common.utils._run_process", mock_run_process): + launch_shell( + host=str(tmp_path / "test.sock"), + context="remote", + allow=["*"], + unsafe=False, + use_profiles=False, + j_commands=["power"], + ) + + assert "_JMP_SHELL_INIT" in captured_env + init_cmd_arg = captured_cmd[captured_cmd.index("--init-command") + 1] + assert captured_env["_JMP_SHELL_INIT"] not in init_cmd_arg + + +def test_generate_bash_init_limits_completion_to_first_arg(): + content = _generate_shell_init("bash", use_profiles=False, j_commands=["power", "serial"]) + assert "COMP_CWORD" in content + assert "-eq 1" in content + + +def test_launch_shell_zsh_restores_zdotdir(tmp_path, monkeypatch): + monkeypatch.setenv("SHELL", "/usr/bin/zsh") + home_dir = os.path.expanduser("~") + + def mock_run_process(cmd, env, lease=None): + zdotdir = env.get("ZDOTDIR") + if zdotdir: + zshrc = os.path.join(zdotdir, ".zshrc") + with open(zshrc) as f: + first_line = f.readline().strip() + assert "ZDOTDIR=" in first_line + assert home_dir in first_line + return 0 + + with patch("jumpstarter.common.utils._run_process", mock_run_process): + launch_shell( + host=str(tmp_path / "test.sock"), + context="remote", + allow=["*"], + unsafe=False, + use_profiles=False, + j_commands=["power"], + ) + + +def test_launch_shell_zsh_uses_tmpdir_with_zshrc_and_zshenv(tmp_path, monkeypatch): + monkeypatch.setenv("SHELL", "/usr/bin/zsh") + temp_dirs = [] + + def mock_run_process(cmd, env, lease=None): + zdotdir = env.get("ZDOTDIR") + if zdotdir: + temp_dirs.append(zdotdir) + entries = sorted(os.listdir(zdotdir)) + assert entries == [".zshenv", ".zshrc"], f"Expected .zshenv and .zshrc in ZDOTDIR, found: {entries}" + return 0 + + with patch("jumpstarter.common.utils._run_process", mock_run_process): + launch_shell( + host=str(tmp_path / "test.sock"), + context="remote", + allow=["*"], + unsafe=False, + use_profiles=False, + j_commands=["power"], + ) + + assert len(temp_dirs) == 1 + assert not os.path.exists(temp_dirs[0]) + + +def test_launch_shell_zsh_sources_original_zshenv(tmp_path, monkeypatch): + monkeypatch.setenv("SHELL", "/usr/bin/zsh") + home_dir = os.path.expanduser("~") + original_zshenv = os.path.join(home_dir, ".zshenv") + + def mock_run_process(cmd, env, lease=None): + zdotdir = env.get("ZDOTDIR") + if zdotdir: + zshenv_path = os.path.join(zdotdir, ".zshenv") + assert os.path.exists(zshenv_path), ".zshenv must exist in temp ZDOTDIR" + with open(zshenv_path) as f: + content = f.read() + assert original_zshenv in content, ( + f".zshenv must source original {original_zshenv}" + ) + return 0 + + with patch("jumpstarter.common.utils._run_process", mock_run_process): + launch_shell( + host=str(tmp_path / "test.sock"), + context="remote", + allow=["*"], + unsafe=False, + use_profiles=False, + j_commands=["power"], + ) + + +@pytest.mark.skipif(not shutil.which("zsh"), reason="zsh not installed") +def test_zsh_init_does_not_produce_compdef_errors(): + init_content = _generate_shell_init("zsh", use_profiles=False, j_commands=["power", "serial"]) + with tempfile.NamedTemporaryFile(mode="w", suffix=".zsh", delete=False) as f: + f.write(init_content) + init_file = f.name + try: + result = subprocess.run( + ["zsh", "-c", f"source {init_file}; exit 0"], + env={"HOME": "/nonexistent", "PATH": os.environ.get("PATH", "")}, + capture_output=True, + text=True, + timeout=10, + ) + assert "command not found: compdef" not in result.stderr + assert result.returncode == 0 + finally: + os.unlink(init_file) + + +@pytest.mark.skipif(not shutil.which("bash"), reason="bash not installed") +def test_bash_init_produces_no_errors(): + init_content = _generate_shell_init("bash", use_profiles=False, j_commands=["power", "serial"]) + with tempfile.NamedTemporaryFile(mode="w", suffix=".sh", delete=False) as f: + f.write(init_content) + rcfile = f.name + try: + result = subprocess.run( + ["bash", "-c", f"source {rcfile}; exit 0"], + env={"HOME": "/nonexistent", "PATH": os.environ.get("PATH", "")}, + capture_output=True, + text=True, + timeout=10, + ) + assert "command not found" not in result.stderr + assert result.returncode == 0 + finally: + os.unlink(rcfile) + + +@pytest.mark.skipif(not shutil.which("fish"), reason="fish not installed") +def test_fish_init_produces_no_errors(): + init_content = _generate_shell_init("fish", use_profiles=False, j_commands=["power", "serial"]) + result = subprocess.run( + ["fish", "--init-command", init_content, "-c", "exit 0"], + env={"HOME": "/nonexistent", "PATH": os.environ.get("PATH", "")}, + capture_output=True, + text=True, + timeout=10, + ) + assert "command not found" not in result.stderr + assert result.returncode == 0 + + +def test_launch_zsh_sets_prompt_after_profile_in_init(tmp_path, monkeypatch): + monkeypatch.setenv("SHELL", "/usr/bin/zsh") + captured_zshrc = [] + + def mock_run_process(cmd, env, lease=None): + zdotdir = env.get("ZDOTDIR") + if zdotdir: + zshrc = os.path.join(zdotdir, ".zshrc") + with open(zshrc) as f: + captured_zshrc.append(f.read()) + return 0 + + with patch("jumpstarter.common.utils._run_process", mock_run_process): + launch_shell( + host=str(tmp_path / "test.sock"), + context="test-device", + allow=["*"], + unsafe=False, + use_profiles=True, + j_commands=["power"], + ) + + assert len(captured_zshrc) == 1 + content = captured_zshrc[0] + assert "PROMPT=" in content + zshrc_pos = content.index(".zshrc") + prompt_pos = content.index("PROMPT=") + assert prompt_pos > zshrc_pos + + +def test_launch_zsh_passes_context_via_env(tmp_path, monkeypatch): + monkeypatch.setenv("SHELL", "/usr/bin/zsh") + captured_env = {} + + def mock_run_process(cmd, env, lease=None): + captured_env.update(env) + return 0 + + with patch("jumpstarter.common.utils._run_process", mock_run_process): + launch_shell( + host=str(tmp_path / "test.sock"), + context="test-device", + allow=["*"], + unsafe=False, + use_profiles=False, + ) + + assert captured_env.get("_JMP_SHELL_CONTEXT") == "test-device" + + +def test_launch_zsh_prompt_references_env_var_not_literal_context(tmp_path, monkeypatch): + monkeypatch.setenv("SHELL", "/usr/bin/zsh") + captured_zshrc = [] + + def mock_run_process(cmd, env, lease=None): + zdotdir = env.get("ZDOTDIR") + if zdotdir: + zshrc = os.path.join(zdotdir, ".zshrc") + with open(zshrc) as f: + captured_zshrc.append(f.read()) + return 0 + + with patch("jumpstarter.common.utils._run_process", mock_run_process): + launch_shell( + host=str(tmp_path / "test.sock"), + context="test-device-name", + allow=["*"], + unsafe=False, + use_profiles=False, + j_commands=["power"], + ) + + content = captured_zshrc[0] + prompt_line = [line for line in content.split("\n") if "PROMPT=" in line][0] + assert "${_JMP_SHELL_CONTEXT}" in prompt_line + assert "test-device-name" not in prompt_line + + +def test_launch_bash_sets_prompt_after_profile_in_init(tmp_path, monkeypatch): + monkeypatch.setenv("SHELL", "/usr/bin/bash") + captured_content = [] + + def mock_run_process(cmd, env, lease=None): + if "--rcfile" in cmd: + rcfile = cmd[cmd.index("--rcfile") + 1] + with open(rcfile) as f: + captured_content.append(f.read()) + return 0 + + with patch("jumpstarter.common.utils._run_process", mock_run_process): + launch_shell( + host=str(tmp_path / "test.sock"), + context="test-device", + allow=["*"], + unsafe=False, + use_profiles=True, + j_commands=["power"], + ) + + assert len(captured_content) == 1 + content = captured_content[0] + assert "PS1=" in content + bashrc_pos = content.index(".bashrc") + ps1_pos = content.index("PS1=") + assert ps1_pos > bashrc_pos + + +def test_launch_bash_passes_context_via_env(tmp_path, monkeypatch): + monkeypatch.setenv("SHELL", "/usr/bin/bash") + captured_env = {} + + def mock_run_process(cmd, env, lease=None): + captured_env.update(env) + return 0 + + with patch("jumpstarter.common.utils._run_process", mock_run_process): + launch_shell( + host=str(tmp_path / "test.sock"), + context="test-device", + allow=["*"], + unsafe=False, + use_profiles=False, + ) + + assert captured_env.get("_JMP_SHELL_CONTEXT") == "test-device" + + +@pytest.mark.skipif(not shutil.which("zsh"), reason="zsh not installed") +def test_zsh_prompt_survives_user_profile_override(): + home_dir = tempfile.mkdtemp() + try: + with open(os.path.join(home_dir, ".zshrc"), "w") as f: + f.write('PROMPT="user-prompt> "\n') + + init_content = _generate_shell_init("zsh", use_profiles=True, j_commands=["power"]) + init_content += ( + 'PROMPT="%F{8}%1~ %F{yellow}⚡%F{white}' + '${_JMP_SHELL_CONTEXT} %F{yellow}➤%f "\n' + ) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".zsh", delete=False) as f: + f.write(init_content) + init_file = f.name + + try: + result = subprocess.run( + ["zsh", "-c", f"source {init_file}; echo \"$PROMPT\""], + env={ + "HOME": home_dir, + "PATH": os.environ.get("PATH", ""), + "_JMP_SHELL_CONTEXT": "test-device", + }, + capture_output=True, + text=True, + timeout=10, + ) + assert result.returncode == 0, f"zsh failed: {result.stderr}" + assert "user-prompt" not in result.stdout + assert "test-device" in result.stdout + finally: + os.unlink(init_file) + finally: + shutil.rmtree(home_dir, ignore_errors=True) + + +@pytest.mark.skipif(not shutil.which("bash"), reason="bash not installed") +def test_bash_prompt_survives_user_profile_override(): + home_dir = tempfile.mkdtemp() + try: + with open(os.path.join(home_dir, ".bashrc"), "w") as f: + f.write('PS1="user-prompt> "\n') + + init_content = _generate_shell_init("bash", use_profiles=True, j_commands=["power"]) + init_content += ( + f'PS1="{ANSI_GRAY}{PROMPT_CWD} {ANSI_YELLOW}⚡{ANSI_WHITE}' + '$_JMP_SHELL_CONTEXT' + f' {ANSI_YELLOW}➤{ANSI_RESET} "\n' + ) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".sh", delete=False) as f: + f.write(init_content) + rcfile = f.name + + try: + result = subprocess.run( + ["bash", "-c", f'source {rcfile}; echo "$PS1"'], + env={ + "HOME": home_dir, + "PATH": os.environ.get("PATH", ""), + "_JMP_SHELL_CONTEXT": "test-device", + }, + capture_output=True, + text=True, + timeout=10, + ) + assert result.returncode == 0, f"bash failed: {result.stderr}" + assert "user-prompt" not in result.stdout + assert "test-device" in result.stdout + finally: + os.unlink(rcfile) + finally: + shutil.rmtree(home_dir, ignore_errors=True)