diff --git a/python/packages/jumpstarter-driver-ssh-mount/README.md b/python/packages/jumpstarter-driver-ssh-mount/README.md index a2ccdef4c..2179cd86e 100644 --- a/python/packages/jumpstarter-driver-ssh-mount/README.md +++ b/python/packages/jumpstarter-driver-ssh-mount/README.md @@ -74,6 +74,29 @@ Ctrl+C to unmount. The `--umount` flag is available as a fallback for mounts that were orphaned (e.g., if the process was killed without cleanup). +## Security: `allow_other` mount option + +By default, sshfs is invoked with `-o allow_other`, which permits all local +users to access the mounted filesystem — not just the user who ran `j mount`. +This is convenient for build workflows where tools run under different UIDs, +but it has security implications on multi-user systems: + +- Any local user can read (and potentially write) files on the remote device + through the mountpoint. +- The option requires that `/etc/fuse.conf` contains `user_allow_other`; + otherwise the mount will fail. + +**Automatic fallback:** if `allow_other` is rejected by FUSE (e.g., +`user_allow_other` is not set), the driver automatically retries the mount +without it. In that case only the mounting user can access the filesystem. + +To explicitly disable `allow_other` without relying on the fallback, you can +override the option via `--extra-args`: + +```shell +j mount /mnt/device -o allow_other=0 +``` + ## API Reference ### SSHMountClient diff --git a/python/packages/jumpstarter-driver-ssh-mount/jumpstarter_driver_ssh_mount/client.py b/python/packages/jumpstarter-driver-ssh-mount/jumpstarter_driver_ssh_mount/client.py index c1520918f..1ecd90cfb 100644 --- a/python/packages/jumpstarter-driver-ssh-mount/jumpstarter_driver_ssh_mount/client.py +++ b/python/packages/jumpstarter-driver-ssh-mount/jumpstarter_driver_ssh_mount/client.py @@ -4,23 +4,28 @@ import shutil import subprocess import sys -import tempfile +import time from dataclasses import dataclass from urllib.parse import urlparse import click -from jumpstarter_driver_composite.client import CompositeClient from jumpstarter_driver_network.adapters import TcpPortforwardAdapter +from jumpstarter_driver_ssh._ssh_utils import cleanup_identity_file, create_temp_identity_file +from jumpstarter.client import DriverClient from jumpstarter.client.core import DriverMethodNotImplemented from jumpstarter.client.decorators import driver_click_command # Timeout in seconds for subprocess calls (mount test run, umount) SUBPROCESS_TIMEOUT = 120 +# Polling parameters for mount readiness check +MOUNT_POLL_INTERVAL = 0.5 +MOUNT_POLL_TIMEOUT = 10.0 + @dataclass(kw_only=True) -class SSHMountClient(CompositeClient): +class SSHMountClient(DriverClient): def cli(self): @driver_click_command(self) @@ -46,6 +51,10 @@ def mount(mountpoint, umount, remote_path, direct, lazy, foreground, extra_args) return mount + @property + def ssh(self): + return self.children["ssh"] + @property def identity(self) -> str | None: return self.ssh.identity @@ -106,7 +115,7 @@ def mount(self, mountpoint, *, remote_path="/", direct=False, foreground=False, foreground=foreground) def _run_sshfs(self, host, port, mountpoint, remote_path, extra_args, *, foreground): - identity_file = self._create_temp_identity_file() + identity_file = create_temp_identity_file(self.identity, self.logger) sshfs_proc = None try: @@ -145,7 +154,7 @@ def _run_sshfs(self, host, port, mountpoint, remote_path, extra_args, *, foregro self.logger.warning("Mountpoint %s may still be mounted after cleanup", mountpoint) else: click.echo(f"Unmounted {mountpoint}") - self._cleanup_identity_file(identity_file) + cleanup_identity_file(identity_file, self.logger) def _start_sshfs_with_fallback(self, sshfs_args, mountpoint): """Start sshfs, retrying without allow_other if it fails on that option. @@ -170,26 +179,23 @@ def _start_sshfs_with_fallback(self, sshfs_args, mountpoint): self._force_umount(mountpoint) - # Use DEVNULL for stderr to avoid SIGPIPE: if we used PIPE and - # closed the parent end after the startup check, sshfs would - # receive SIGPIPE on its next stderr write and terminate. proc = subprocess.Popen( sshfs_args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, ) - # Give sshfs a moment to start and check it hasn't failed immediately - try: - proc.wait(timeout=1) - # If it exited within 1s, something went wrong - raise click.ClickException( - f"sshfs mount failed immediately (exit code {proc.returncode})" - ) - except subprocess.TimeoutExpired: - # Good -- sshfs is still running after 1s. - # Verify the mount is actually active. - if not os.path.ismount(mountpoint): + # Poll until mount is ready or sshfs exits unexpectedly + deadline = time.monotonic() + MOUNT_POLL_TIMEOUT + while True: + ret = proc.poll() + if ret is not None: + raise click.ClickException( + f"sshfs mount failed immediately (exit code {ret})" + ) + if os.path.ismount(mountpoint): + break + if time.monotonic() >= deadline: proc.terminate() try: proc.wait(timeout=5) @@ -197,8 +203,9 @@ def _start_sshfs_with_fallback(self, sshfs_args, mountpoint): proc.kill() proc.wait() raise click.ClickException( - f"sshfs started but {mountpoint} is not mounted" - ) from None + f"sshfs started but {mountpoint} is not mounted after {MOUNT_POLL_TIMEOUT}s" + ) + time.sleep(MOUNT_POLL_INTERVAL) return proc @@ -269,44 +276,6 @@ def _build_sshfs_args(self, host, port, mountpoint, remote_path, identity_file, return sshfs_args - def _create_temp_identity_file(self): - ssh_identity = self.identity - if not ssh_identity: - return None - - fd = None - temp_path = None - try: - # mkstemp creates the file with 0o600 permissions atomically, - # avoiding the TOCTOU window of NamedTemporaryFile + chmod. - fd, temp_path = tempfile.mkstemp(suffix='_ssh_key') - os.write(fd, ssh_identity.encode()) - os.close(fd) - fd = None - self.logger.debug("Created temporary identity file: %s", temp_path) - return temp_path - except Exception as e: - self.logger.error("Failed to create temporary identity file: %s", e) - if fd is not None: - try: - os.close(fd) - except Exception: - pass - if temp_path: - try: - os.unlink(temp_path) - except Exception: - pass - raise - - def _cleanup_identity_file(self, identity_file): - if identity_file: - try: - os.unlink(identity_file) - self.logger.debug("Cleaned up temporary identity file: %s", identity_file) - except Exception as e: - self.logger.warning("Failed to clean up identity file %s: %s", identity_file, e) - def umount(self, mountpoint, *, lazy=False): """Unmount an sshfs filesystem (fallback for orphaned mounts).""" mountpoint = os.path.realpath(mountpoint) diff --git a/python/packages/jumpstarter-driver-ssh-mount/jumpstarter_driver_ssh_mount/driver_test.py b/python/packages/jumpstarter-driver-ssh-mount/jumpstarter_driver_ssh_mount/driver_test.py index a45a9e139..9c36f9e96 100644 --- a/python/packages/jumpstarter-driver-ssh-mount/jumpstarter_driver_ssh_mount/driver_test.py +++ b/python/packages/jumpstarter-driver-ssh-mount/jumpstarter_driver_ssh_mount/driver_test.py @@ -1,17 +1,16 @@ import os -import subprocess from unittest.mock import MagicMock, patch import pytest from jumpstarter_driver_network.driver import TcpNetwork from jumpstarter_driver_ssh.driver import SSHWrapper +from jumpstarter_driver_ssh_mount.client import MOUNT_POLL_INTERVAL from jumpstarter_driver_ssh_mount.driver import SSHMount from jumpstarter.common.exceptions import ConfigurationError from jumpstarter.common.utils import serve -# Test SSH key content used in multiple tests TEST_SSH_KEY = ( "-----BEGIN OPENSSH PRIVATE KEY-----\n" "test-key-content\n" @@ -33,424 +32,488 @@ def _make_ssh_child(default_username="testuser", ssh_identity=None, ssh_identity return SSHWrapper(**kwargs) +def _fake_find_executable(name): + """Return plausible paths per executable name.""" + paths = { + "sshfs": "/usr/bin/sshfs", + "fusermount3": "/usr/bin/fusermount3", + "fusermount": "/usr/bin/fusermount", + } + return paths.get(name) + + +@pytest.fixture +def mount_instance(): + return SSHMount(children={"ssh": _make_ssh_child()}) + + +@pytest.fixture +def mount_instance_with_identity(): + return SSHMount(children={"ssh": _make_ssh_child(ssh_identity=TEST_SSH_KEY)}) + + +@pytest.fixture +def mock_portforward(): + with patch('jumpstarter_driver_ssh_mount.client.TcpPortforwardAdapter') as mock_adapter: + mock_adapter.return_value.__enter__ = MagicMock(return_value=("127.0.0.1", 2222)) + mock_adapter.return_value.__exit__ = MagicMock(return_value=None) + yield mock_adapter + + +@pytest.fixture +def mock_portforward_22(): + with patch('jumpstarter_driver_ssh_mount.client.TcpPortforwardAdapter') as mock_adapter: + mock_adapter.return_value.__enter__ = MagicMock(return_value=("127.0.0.1", 22)) + mock_adapter.return_value.__exit__ = MagicMock(return_value=None) + yield mock_adapter + + +# --------------------------------------------------------------------------- +# Driver configuration tests +# --------------------------------------------------------------------------- + def test_ssh_mount_requires_ssh_child(): """Test that SSHMount driver requires an ssh child""" with pytest.raises(ConfigurationError, match="'ssh' child is required"): SSHMount() -def test_mount_sshfs_not_installed(): - """Test mount fails gracefully when sshfs is not installed""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) +# --------------------------------------------------------------------------- +# _build_sshfs_args unit tests (argument construction validated independently) +# --------------------------------------------------------------------------- + +def test_build_sshfs_args_basic(mount_instance): + """Test basic sshfs argument construction""" + with serve(mount_instance) as client: + args = client._build_sshfs_args("192.168.1.1", 22, "/mnt/remote", "/", None, None) + assert args[0] == "sshfs" + assert "testuser@192.168.1.1:/" in args + assert "/mnt/remote" in args + assert "-p" not in args + + +def test_build_sshfs_args_custom_port(mount_instance): + """Test sshfs args include -p for non-default port""" + with serve(mount_instance) as client: + args = client._build_sshfs_args("192.168.1.1", 2222, "/mnt/remote", "/", None, None) + assert "-p" in args + assert "2222" in args + + +def test_build_sshfs_args_with_identity(mount_instance): + """Test sshfs args include IdentityFile when identity file is provided""" + with serve(mount_instance) as client: + args = client._build_sshfs_args("192.168.1.1", 22, "/mnt/remote", "/", + "/tmp/my_key", None) + identity_opts = [args[i + 1] for i in range(len(args) - 1) + if args[i] == "-o" and args[i + 1].startswith("IdentityFile=")] + assert len(identity_opts) == 1 + assert identity_opts[0] == "IdentityFile=/tmp/my_key" + +def test_build_sshfs_args_allow_other_present(mount_instance): + """Test sshfs args include allow_other by default""" + with serve(mount_instance) as client: + args = client._build_sshfs_args("192.168.1.1", 22, "/mnt/remote", "/", None, None) + assert "allow_other" in args + + +def test_build_sshfs_args_with_extra_args(mount_instance): + """Test extra args are prefixed with -o""" + with serve(mount_instance) as client: + args = client._build_sshfs_args("192.168.1.1", 22, "/mnt/remote", "/", None, + ["reconnect", "cache=yes"]) + for extra in ["reconnect", "cache=yes"]: + idx = args.index(extra) + assert args[idx - 1] == "-o" + + +def test_build_sshfs_args_remote_path(mount_instance): + """Test sshfs args use the correct remote path""" + with serve(mount_instance) as client: + args = client._build_sshfs_args("10.0.0.1", 22, "/mnt/remote", "/home/user", None, None) + assert "testuser@10.0.0.1:/home/user" in args + + +def test_build_sshfs_args_no_username(): + """Test sshfs args without default username""" + instance = SSHMount(children={"ssh": _make_ssh_child(default_username="")}) with serve(instance) as client: + args = client._build_sshfs_args("10.0.0.1", 22, "/mnt/remote", "/", None, None) + assert "10.0.0.1:/" in args + assert not any("@" in a for a in args if ":" in a) + + +# --------------------------------------------------------------------------- +# Mount workflow tests +# --------------------------------------------------------------------------- + +def test_mount_sshfs_not_installed(mount_instance): + """Test mount fails gracefully when sshfs is not installed""" + with serve(mount_instance) as client: with patch.object(client, '_find_executable', return_value=None): with pytest.raises(Exception, match="sshfs is not installed"): client.mount("/tmp/test-mount") -def test_mount_sshfs_success(): +def test_mount_sshfs_success(mount_instance, mock_portforward): """Test successful sshfs mount via port forwarding with subshell""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) - - with serve(instance) as client: + with serve(mount_instance) as client: mock_proc = MagicMock() - mock_proc.poll.return_value = 0 # sshfs already exited + mock_proc.poll.return_value = 0 mock_proc.stderr = None - with patch.object(client, '_find_executable', return_value="/usr/bin/sshfs"): - with patch('subprocess.run') as mock_run: - with patch('subprocess.Popen', return_value=mock_proc): - # Test run succeeds, then foreground popen exits immediately (simulated) - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") - mock_proc.wait.side_effect = [None] # wait returns immediately (exited) - - with patch('os.makedirs'): - with patch('jumpstarter_driver_ssh_mount.client.TcpPortforwardAdapter') as mock_adapter: - mock_adapter.return_value.__enter__ = MagicMock(return_value=("127.0.0.1", 2222)) - mock_adapter.return_value.__exit__ = MagicMock(return_value=None) - - # The foreground popen will fail because sshfs exits immediately, - # which raises ClickException. That's expected in unit tests - # where sshfs isn't really running. - with pytest.raises(Exception, match="sshfs mount failed"): - client.mount("/tmp/test-mount", remote_path="/home/user") - - # Verify test run was called with correct args - test_run_args = mock_run.call_args_list[0][0][0] - assert test_run_args[0] == "sshfs" - assert "testuser@127.0.0.1:/home/user" in test_run_args - assert os.path.realpath("/tmp/test-mount") in test_run_args - assert "-p" in test_run_args - assert "2222" in test_run_args - # -f should NOT be in the test run (it's removed for validation) - assert "-f" not in test_run_args - - -def test_mount_sshfs_with_identity(): - """Test sshfs mount with SSH identity""" - instance = SSHMount( - children={"ssh": _make_ssh_child(ssh_identity=TEST_SSH_KEY)}, - ) + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + mock_proc.wait.side_effect = [None] - with serve(instance) as client: + with pytest.raises(Exception, match="sshfs mount failed"): + client.mount("/tmp/test-mount", remote_path="/home/user") + + test_run_args = mock_run.call_args_list[0][0][0] + assert test_run_args[0] == "sshfs" + assert "testuser@127.0.0.1:/home/user" in test_run_args + assert os.path.realpath("/tmp/test-mount") in test_run_args + assert "-p" in test_run_args + assert "2222" in test_run_args + assert "-f" not in test_run_args + + +def test_mount_sshfs_with_identity(mount_instance_with_identity, mock_portforward_22): + """Test sshfs mount with SSH identity""" + with serve(mount_instance_with_identity) as client: mock_proc = MagicMock() mock_proc.poll.return_value = 0 mock_proc.stderr = None - with patch.object(client, '_find_executable', return_value="/usr/bin/sshfs"): - with patch('subprocess.run') as mock_run: - with patch('subprocess.Popen', return_value=mock_proc): - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") - mock_proc.wait.side_effect = [None] - - with patch('os.makedirs'): - with patch('jumpstarter_driver_ssh_mount.client.TcpPortforwardAdapter') as mock_adapter: - mock_adapter.return_value.__enter__ = MagicMock(return_value=("127.0.0.1", 22)) - mock_adapter.return_value.__exit__ = MagicMock(return_value=None) - - with pytest.raises(Exception, match="sshfs mount failed"): - client.mount("/tmp/test-mount") + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + mock_proc.wait.side_effect = [None] - test_run_args = mock_run.call_args_list[0][0][0] - identity_opts = [ - test_run_args[i + 1] for i in range(len(test_run_args) - 1) - if test_run_args[i] == "-o" and test_run_args[i + 1].startswith("IdentityFile=") - ] - assert len(identity_opts) == 1 + with pytest.raises(Exception, match="sshfs mount failed"): + client.mount("/tmp/test-mount") + test_run_args = mock_run.call_args_list[0][0][0] + identity_opts = [ + test_run_args[i + 1] for i in range(len(test_run_args) - 1) + if test_run_args[i] == "-o" and test_run_args[i + 1].startswith("IdentityFile=") + ] + assert len(identity_opts) == 1 -def test_mount_sshfs_allow_other_fallback(): - """Test sshfs mount falls back when allow_other fails, removing both -o and allow_other""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) - with serve(instance) as client: +def test_mount_sshfs_allow_other_fallback(mount_instance, mock_portforward_22): + """Test sshfs mount falls back when allow_other fails""" + with serve(mount_instance) as client: mock_proc = MagicMock() mock_proc.poll.return_value = 0 mock_proc.stderr = None - with patch.object(client, '_find_executable', return_value="/usr/bin/sshfs"): - with patch('subprocess.run') as mock_run: - with patch('subprocess.Popen', return_value=mock_proc): - # First test run fails with allow_other, second succeeds - mock_run.side_effect = [ - MagicMock(returncode=1, stdout="", stderr="allow_other: permission denied"), - MagicMock(returncode=0, stdout="", stderr=""), # retry without allow_other - MagicMock(returncode=0, stdout="", stderr=""), # force_umount - ] - mock_proc.wait.side_effect = [None] - - with patch('os.makedirs'): - with patch('jumpstarter_driver_ssh_mount.client.TcpPortforwardAdapter') as mock_adapter: - mock_adapter.return_value.__enter__ = MagicMock(return_value=("127.0.0.1", 22)) - mock_adapter.return_value.__exit__ = MagicMock(return_value=None) - - with pytest.raises(Exception, match="sshfs mount failed"): - client.mount("/tmp/test-mount") - - # Second test run should not have allow_other - second_call_args = mock_run.call_args_list[1][0][0] - assert "allow_other" not in second_call_args - # Verify no orphaned -o flags - for i, arg in enumerate(second_call_args): - if arg == "-o": - assert i + 1 < len(second_call_args), "Orphaned -o flag found" - assert not second_call_args[i + 1].startswith("-"), \ - f"Orphaned -o flag followed by {second_call_args[i + 1]}" - - -def test_mount_sshfs_generic_failure(): + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + ): + mock_run.side_effect = [ + MagicMock(returncode=1, stdout="", stderr="allow_other: permission denied"), + MagicMock(returncode=0, stdout="", stderr=""), + MagicMock(returncode=0, stdout="", stderr=""), + ] + mock_proc.wait.side_effect = [None] + + with pytest.raises(Exception, match="sshfs mount failed"): + client.mount("/tmp/test-mount") + + second_call_args = mock_run.call_args_list[1][0][0] + assert "allow_other" not in second_call_args + for i, arg in enumerate(second_call_args): + if arg == "-o": + assert i + 1 < len(second_call_args), "Orphaned -o flag found" + assert not second_call_args[i + 1].startswith("-"), \ + f"Orphaned -o flag followed by {second_call_args[i + 1]}" + + +def test_mount_sshfs_generic_failure(mount_instance, mock_portforward_22): """Test mount failure with a non-allow_other error""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) + with serve(mount_instance) as client: + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('os.makedirs'), + ): + mock_run.return_value = MagicMock(returncode=1, stdout="", stderr="Connection refused") + + with pytest.raises(Exception, match="sshfs mount failed"): + client.mount("/tmp/test-mount") - with serve(instance) as client: - with patch.object(client, '_find_executable', return_value="/usr/bin/sshfs"): - with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock( - returncode=1, stdout="", stderr="Connection refused" - ) - with patch('os.makedirs'): - with patch('jumpstarter_driver_ssh_mount.client.TcpPortforwardAdapter') as mock_adapter: - mock_adapter.return_value.__enter__ = MagicMock(return_value=("127.0.0.1", 22)) - mock_adapter.return_value.__exit__ = MagicMock(return_value=None) - - with pytest.raises(Exception, match="sshfs mount failed"): - client.mount("/tmp/test-mount") - - # First call is the sshfs test run (should not retry since - # error is not allow_other). Second call is _force_umount - # in the finally block cleanup. - assert mock_run.call_count == 2 - # Verify the first call was the sshfs test run - first_call_args = mock_run.call_args_list[0][0][0] - assert first_call_args[0] == "sshfs" + assert mock_run.call_count == 2 + first_call_args = mock_run.call_args_list[0][0][0] + assert first_call_args[0] == "sshfs" def test_mount_sshfs_direct_success(): """Test sshfs mount using direct TCP address""" - instance = SSHMount( - children={"ssh": _make_ssh_child(host="10.0.0.1", port=2222)}, - ) + instance = SSHMount(children={"ssh": _make_ssh_child(host="10.0.0.1", port=2222)}) with serve(instance) as client: mock_proc = MagicMock() mock_proc.poll.return_value = 0 mock_proc.stderr = None - with patch.object(client, '_find_executable', return_value="/usr/bin/sshfs"): - with patch('subprocess.run') as mock_run: - with patch('subprocess.Popen', return_value=mock_proc): - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") - mock_proc.wait.side_effect = [None] + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + mock_proc.wait.side_effect = [None] - with patch('os.makedirs'): - with pytest.raises(Exception, match="sshfs mount failed"): - client.mount("/tmp/test-mount", direct=True) + with pytest.raises(Exception, match="sshfs mount failed"): + client.mount("/tmp/test-mount", direct=True) - test_run_args = mock_run.call_args_list[0][0][0] - assert test_run_args[0] == "sshfs" - assert "testuser@10.0.0.1:/" in test_run_args - assert "-p" in test_run_args - assert "2222" in test_run_args + test_run_args = mock_run.call_args_list[0][0][0] + assert test_run_args[0] == "sshfs" + assert "testuser@10.0.0.1:/" in test_run_args + assert "-p" in test_run_args + assert "2222" in test_run_args -def test_mount_sshfs_direct_fallback_to_portforward(): +def test_mount_sshfs_direct_fallback_to_portforward(mount_instance, mock_portforward): """Test that direct mount falls back to port forwarding on failure""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) - - with serve(instance) as client: + with serve(mount_instance) as client: mock_proc = MagicMock() mock_proc.poll.return_value = 0 mock_proc.stderr = None - with patch.object(client, '_find_executable', return_value="/usr/bin/sshfs"): - with patch('subprocess.run') as mock_run: - with patch('subprocess.Popen', return_value=mock_proc): - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") - mock_proc.wait.side_effect = [None] + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + mock_proc.wait.side_effect = [None] - with patch('os.makedirs'): - with patch('jumpstarter_driver_ssh_mount.client.TcpPortforwardAdapter') as mock_adapter: - mock_adapter.return_value.__enter__ = MagicMock(return_value=("127.0.0.1", 3333)) - mock_adapter.return_value.__exit__ = MagicMock(return_value=None) + original_ssh = client.ssh - original_ssh = client.ssh + class FakeTcp: + def address(self): + raise ValueError("not available") - class FakeTcp: - def address(self): - raise ValueError("not available") + class FakeSsh: + def __getattr__(self, name): + if name == "tcp": + return FakeTcp() + return getattr(original_ssh, name) - class FakeSsh: - def __getattr__(self, name): - if name == "tcp": - return FakeTcp() - return getattr(original_ssh, name) + with patch.object(client, 'children', {**client.children, "ssh": FakeSsh()}): + with pytest.raises(Exception, match="sshfs mount failed"): + client.mount("/tmp/test-mount", direct=True) - with patch.object(client, 'ssh', FakeSsh()): - with pytest.raises(Exception, match="sshfs mount failed"): - client.mount("/tmp/test-mount", direct=True) + test_run_args = mock_run.call_args_list[0][0][0] + assert "2222" in test_run_args - test_run_args = mock_run.call_args_list[0][0][0] - # Should have used port forwarding (port 3333) - assert "3333" in test_run_args - -def test_mount_foreground_mode(): +def test_mount_foreground_mode(mount_instance, mock_portforward_22): """Test that foreground flag blocks on sshfs without spawning subshell""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) - - with serve(instance) as client: + with serve(mount_instance) as client: mock_proc = MagicMock() - mock_proc.poll.return_value = None # Still running when cleanup checks - mock_proc.wait.side_effect = [ - subprocess.TimeoutExpired("sshfs", 1), # First wait (startup check) - still running - None, # Second wait (foreground blocking) - exited - None, # Third wait (cleanup after terminate) - exited - ] + mock_proc.poll.return_value = None mock_proc.returncode = 0 - with patch.object(client, '_find_executable', return_value="/usr/bin/sshfs"): - with patch('subprocess.run') as mock_run: - with patch('subprocess.Popen', return_value=mock_proc) as mock_popen: - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + poll_calls = [0] + def poll_side_effect(): + poll_calls[0] += 1 + if poll_calls[0] >= 3: + return None + return None + mock_proc.poll.side_effect = poll_side_effect - with patch('os.makedirs'): - with patch('os.path.ismount', return_value=True): - with patch('jumpstarter_driver_ssh_mount.client.TcpPortforwardAdapter') as mock_adapter: - mock_adapter.return_value.__enter__ = MagicMock(return_value=("127.0.0.1", 22)) - mock_adapter.return_value.__exit__ = MagicMock(return_value=None) + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc) as mock_popen, + patch('os.makedirs'), + patch('os.path.ismount', return_value=True), + patch('jumpstarter_driver_ssh_mount.client.time.sleep'), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + mock_proc.wait.return_value = None - client.mount("/tmp/test-mount", foreground=True) + client.mount("/tmp/test-mount", foreground=True) - # Should have waited on sshfs (foreground mode) - assert mock_proc.wait.call_count >= 2 - # Port forward should be cleaned up - mock_adapter.return_value.__exit__.assert_called() - # Verify -f flag is in the Popen args - popen_args = mock_popen.call_args[0][0] - assert "-f" in popen_args + assert mock_proc.wait.call_count >= 1 + mock_portforward_22.return_value.__exit__.assert_called() + popen_args = mock_popen.call_args[0][0] + assert "-f" in popen_args -def test_mount_subshell_mode(): +def test_mount_subshell_mode(mount_instance, mock_portforward_22): """Test that default mode spawns a subshell""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) - - with serve(instance) as client: + with serve(mount_instance) as client: mock_proc = MagicMock() - mock_proc.poll.return_value = None # Still running when cleanup checks - mock_proc.wait.side_effect = [ - subprocess.TimeoutExpired("sshfs", 1), # Startup check - still running - None, # Cleanup wait after terminate - exited - ] + mock_proc.poll.return_value = None mock_proc.returncode = 0 - with patch.object(client, '_find_executable', return_value="/usr/bin/sshfs"): - with patch('subprocess.run') as mock_run: - with patch('subprocess.Popen', return_value=mock_proc): - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") - - with patch('os.makedirs'): - with patch('os.path.ismount', return_value=True): - with patch('jumpstarter_driver_ssh_mount.client.TcpPortforwardAdapter') as mock_adapter: - mock_adapter.return_value.__enter__ = MagicMock(return_value=("127.0.0.1", 22)) - mock_adapter.return_value.__exit__ = MagicMock(return_value=None) + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + patch('os.path.ismount', return_value=True), + patch('jumpstarter_driver_ssh_mount.client.time.sleep'), + patch.object(client, '_run_subshell') as mock_subshell, + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") - with patch.object(client, '_run_subshell') as mock_subshell: - client.mount("/tmp/test-mount") + client.mount("/tmp/test-mount") - # Subshell should have been called - resolved = os.path.realpath("/tmp/test-mount") - mock_subshell.assert_called_once_with(resolved, "/") + resolved = os.path.realpath("/tmp/test-mount") + mock_subshell.assert_called_once_with(resolved, "/") -def test_mount_cleanup_on_failure(): +def test_mount_cleanup_on_failure(mount_instance_with_identity, mock_portforward_22): """Test that identity file is cleaned up when mount fails""" - instance = SSHMount( - children={"ssh": _make_ssh_child(ssh_identity=TEST_SSH_KEY)}, - ) + with serve(mount_instance_with_identity) as client: + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('os.makedirs'), + patch('os.unlink') as mock_unlink, + ): + mock_run.return_value = MagicMock(returncode=1, stdout="", stderr="Connection refused") + + with pytest.raises(Exception, match="sshfs mount failed"): + client.mount("/tmp/test-mount") - with serve(instance) as client: - with patch.object(client, '_find_executable', return_value="/usr/bin/sshfs"): - with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock( - returncode=1, stdout="", stderr="Connection refused" - ) - with patch('os.makedirs'): - with patch('jumpstarter_driver_ssh_mount.client.TcpPortforwardAdapter') as mock_adapter: - mock_adapter.return_value.__enter__ = MagicMock(return_value=("127.0.0.1", 22)) - mock_adapter.return_value.__exit__ = MagicMock(return_value=None) - - with patch('os.unlink') as mock_unlink: - with pytest.raises(Exception, match="sshfs mount failed"): - client.mount("/tmp/test-mount") - - # Identity file should be cleaned up on failure - # Verify unlink was called with a path ending in _ssh_key - assert mock_unlink.called - unlink_path = mock_unlink.call_args_list[-1][0][0] - assert unlink_path.endswith("_ssh_key") - - -def test_umount_with_fusermount(): - """Test unmount using fusermount""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) + assert mock_unlink.called + unlink_path = mock_unlink.call_args_list[-1][0][0] + assert unlink_path.endswith("_ssh_key") - with serve(instance) as client: - def _fake_find(name): - return "/usr/bin/fusermount" if name == "fusermount" else None - with patch.object(client, '_find_executable', side_effect=_fake_find): - with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") +# --------------------------------------------------------------------------- +# Unmount tests +# --------------------------------------------------------------------------- - client.umount("/tmp/test-mount") +def test_umount_with_fusermount(mount_instance): + """Test unmount using fusermount""" + with serve(mount_instance) as client: + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + client.umount("/tmp/test-mount") - assert mock_run.called - call_args = mock_run.call_args[0][0] - assert call_args[0] == "/usr/bin/fusermount" - assert "-u" in call_args + call_args = mock_run.call_args[0][0] + assert call_args[0] == "/usr/bin/fusermount3" + assert "-u" in call_args -def test_umount_with_system_umount_fallback(): +def test_umount_with_system_umount_fallback(mount_instance): """Test unmount falls back to system umount when fusermount is not available""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) + with serve(mount_instance) as client: + with ( + patch.object(client, '_find_executable', return_value=None), + patch('subprocess.run') as mock_run, + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + client.umount("/tmp/test-mount") - with serve(instance) as client: - with patch.object(client, '_find_executable', return_value=None): - with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + call_args = mock_run.call_args[0][0] + assert call_args[0] == "umount" - client.umount("/tmp/test-mount") - assert mock_run.called - call_args = mock_run.call_args[0][0] - assert call_args[0] == "umount" +def test_umount_lazy(mount_instance): + """Test lazy unmount""" + with serve(mount_instance) as client: + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + client.umount("/tmp/test-mount", lazy=True) + call_args = mock_run.call_args[0][0] + assert "-z" in call_args -def test_umount_lazy(): - """Test lazy unmount""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) - with serve(instance) as client: - def _fake_find(name): - return "/usr/bin/fusermount" if name == "fusermount" else None +def test_umount_failure(mount_instance): + """Test unmount failure""" + with serve(mount_instance) as client: + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + ): + mock_run.return_value = MagicMock(returncode=1, stdout="", stderr="not mounted") + + with pytest.raises(Exception, match="Unmount failed"): + client.umount("/tmp/test-mount") - with patch.object(client, '_find_executable', side_effect=_fake_find): - with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") - client.umount("/tmp/test-mount", lazy=True) +def test_umount_prefers_fusermount3(mount_instance): + """Test that fusermount3 is preferred over fusermount when both are available""" + with serve(mount_instance) as client: + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + client.umount("/tmp/test-mount") - assert mock_run.called - call_args = mock_run.call_args[0][0] - assert "-z" in call_args + call_args = mock_run.call_args[0][0] + assert call_args[0] == "/usr/bin/fusermount3" -def test_umount_failure(): - """Test unmount failure""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) +def test_umount_lazy_macos_uses_force(mount_instance): + """Test that lazy unmount on macOS uses -f instead of -l""" + with serve(mount_instance) as client: + with ( + patch.object(client, '_find_executable', return_value=None), + patch('subprocess.run') as mock_run, + patch('jumpstarter_driver_ssh_mount.client.sys') as mock_sys, + ): + mock_sys.platform = "darwin" + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") - with serve(instance) as client: - def _fake_find(name): - return "/usr/bin/fusermount" if name == "fusermount" else None + client.umount("/tmp/test-mount", lazy=True) - with patch.object(client, '_find_executable', side_effect=_fake_find): - with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=1, stdout="", stderr="not mounted") + call_args = mock_run.call_args[0][0] + assert "-f" in call_args + assert "-l" not in call_args - with pytest.raises(Exception, match="Unmount failed"): - client.umount("/tmp/test-mount") +def test_umount_passes_timeout(mount_instance): + """Test that umount subprocess calls include SUBPROCESS_TIMEOUT""" + with serve(mount_instance) as client: + with ( + patch.object(client, '_find_executable', return_value=None), + patch('subprocess.run') as mock_run, + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + client.umount("/tmp/test-mount") -def test_cli_has_mount_and_umount_flag(): - """Test that the CLI exposes mount command with --umount and --foreground flags""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) + assert mock_run.call_args[1].get("timeout") == 120 - with serve(instance) as client: + +# --------------------------------------------------------------------------- +# CLI tests +# --------------------------------------------------------------------------- + +def test_cli_has_mount_and_umount_flag(mount_instance): + """Test that the CLI exposes mount command with --umount and --foreground flags""" + with serve(mount_instance) as client: cli = client.cli() from click.testing import CliRunner runner = CliRunner() @@ -460,13 +523,9 @@ def test_cli_has_mount_and_umount_flag(): assert "--foreground" in result.output -def test_cli_dispatches_mount(): +def test_cli_dispatches_mount(mount_instance): """Test that CLI invocation with a mountpoint dispatches to self.mount()""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) - - with serve(instance) as client: + with serve(mount_instance) as client: cli = client.cli() from click.testing import CliRunner runner = CliRunner() @@ -483,13 +542,9 @@ def test_cli_dispatches_mount(): ) -def test_cli_dispatches_umount(): +def test_cli_dispatches_umount(mount_instance): """Test that CLI invocation with --umount dispatches to self.umount()""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) - - with serve(instance) as client: + with serve(mount_instance) as client: cli = client.cli() from click.testing import CliRunner runner = CliRunner() @@ -500,202 +555,177 @@ def test_cli_dispatches_umount(): mock_umount.assert_called_once_with("/tmp/test-cli-mount", lazy=True) -def test_mount_foreground_keyboard_interrupt(): - """Test that KeyboardInterrupt during foreground mode terminates sshfs and unmounts""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) +# --------------------------------------------------------------------------- +# Polling / mount-readiness tests +# --------------------------------------------------------------------------- - with serve(instance) as client: +def test_mount_polling_waits_for_mount(mount_instance, mock_portforward_22): + """Test that the polling loop waits for os.path.ismount to return True""" + with serve(mount_instance) as client: mock_proc = MagicMock() - mock_proc.poll.return_value = None # Still running - mock_proc.wait.side_effect = [ - subprocess.TimeoutExpired("sshfs", 1), # Startup check - still running - KeyboardInterrupt(), # Foreground blocking - user presses Ctrl+C - None, # Cleanup wait after terminate - ] + mock_proc.poll.return_value = None mock_proc.returncode = 0 - with patch.object(client, '_find_executable', return_value="/usr/bin/sshfs"): - with patch('subprocess.run') as mock_run: - with patch('subprocess.Popen', return_value=mock_proc): - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + ismount_calls = [0] + def ismount_side_effect(path): + ismount_calls[0] += 1 + return ismount_calls[0] >= 3 - with patch('os.makedirs'): - with patch('os.path.ismount', return_value=True): - with patch('jumpstarter_driver_ssh_mount.client.TcpPortforwardAdapter') as mock_adapter: - mock_adapter.return_value.__enter__ = MagicMock(return_value=("127.0.0.1", 22)) - mock_adapter.return_value.__exit__ = MagicMock(return_value=None) + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + patch('os.path.ismount', side_effect=ismount_side_effect), + patch('jumpstarter_driver_ssh_mount.client.time.sleep') as mock_sleep, + patch.object(client, '_run_subshell'), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") - client.mount("/tmp/test-mount", foreground=True) + client.mount("/tmp/test-mount") - # sshfs should have been terminated - mock_proc.terminate.assert_called_once() + assert mock_sleep.call_count >= 2 + mock_sleep.assert_called_with(MOUNT_POLL_INTERVAL) -def test_umount_passes_timeout(): - """Test that umount subprocess calls include SUBPROCESS_TIMEOUT""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) - - with serve(instance) as client: - with patch.object(client, '_find_executable', return_value=None): - with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") +def test_mount_polling_timeout(mount_instance, mock_portforward_22): + """Test that mount fails if mountpoint is never mounted within timeout""" + with serve(mount_instance) as client: + mock_proc = MagicMock() + mock_proc.poll.return_value = None + mock_proc.returncode = 0 - client.umount("/tmp/test-mount") + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + patch('os.path.ismount', return_value=False), + patch('jumpstarter_driver_ssh_mount.client.time.sleep'), + patch('jumpstarter_driver_ssh_mount.client.MOUNT_POLL_TIMEOUT', 0), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") - # Verify timeout=120 is passed - assert mock_run.call_args[1].get("timeout") == 120 + with pytest.raises(Exception, match="is not mounted"): + client.mount("/tmp/test-mount", foreground=True) + mock_proc.terminate.assert_called() -def test_mount_port_22_omits_p_flag(): - """Test that port 22 does not add -p flag to sshfs args""" - instance = SSHMount( - children={"ssh": _make_ssh_child(port=22)}, - ) - with serve(instance) as client: +def test_mount_sshfs_not_mounted_after_startup(mount_instance, mock_portforward_22): + """Test that mount fails if sshfs starts but mountpoint is not actually mounted""" + with serve(mount_instance) as client: mock_proc = MagicMock() - mock_proc.poll.return_value = 0 - mock_proc.stderr = None - - with patch.object(client, '_find_executable', return_value="/usr/bin/sshfs"): - with patch('subprocess.run') as mock_run: - with patch('subprocess.Popen', return_value=mock_proc): - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") - mock_proc.wait.side_effect = [None] - - with patch('os.makedirs'): - with patch('jumpstarter_driver_ssh_mount.client.TcpPortforwardAdapter') as mock_adapter: - mock_adapter.return_value.__enter__ = MagicMock(return_value=("127.0.0.1", 22)) - mock_adapter.return_value.__exit__ = MagicMock(return_value=None) - - with pytest.raises(Exception, match="sshfs mount failed"): - client.mount("/tmp/test-mount") + mock_proc.poll.return_value = None + mock_proc.returncode = 0 - test_run_args = mock_run.call_args_list[0][0][0] - assert "-p" not in test_run_args + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + patch('os.path.ismount', return_value=False), + patch('jumpstarter_driver_ssh_mount.client.time.sleep'), + patch('jumpstarter_driver_ssh_mount.client.MOUNT_POLL_TIMEOUT', 0), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + with pytest.raises(Exception, match="is not mounted"): + client.mount("/tmp/test-mount", foreground=True) -def test_umount_prefers_fusermount3(): - """Test that fusermount3 is preferred over fusermount when both are available""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) + mock_proc.terminate.assert_called() - with serve(instance) as client: - def _fake_find(name): - if name == "fusermount3": - return "/usr/bin/fusermount3" - if name == "fusermount": - return "/usr/bin/fusermount" - return None - with patch.object(client, '_find_executable', side_effect=_fake_find): - with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") +# --------------------------------------------------------------------------- +# Foreground / KeyboardInterrupt tests +# --------------------------------------------------------------------------- - client.umount("/tmp/test-mount") - - call_args = mock_run.call_args[0][0] - assert call_args[0] == "/usr/bin/fusermount3" +def test_mount_foreground_keyboard_interrupt(mount_instance, mock_portforward_22): + """Test that KeyboardInterrupt during foreground mode terminates sshfs and unmounts""" + with serve(mount_instance) as client: + mock_proc = MagicMock() + mock_proc.poll.return_value = None + mock_proc.returncode = 0 + mock_proc.wait.side_effect = [ + KeyboardInterrupt(), + None, + ] -def test_umount_lazy_macos_uses_force(): - """Test that lazy unmount on macOS uses -f instead of -l""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + patch('os.path.ismount', return_value=True), + patch('jumpstarter_driver_ssh_mount.client.time.sleep'), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") - with serve(instance) as client: - with patch.object(client, '_find_executable', return_value=None): - with patch('subprocess.run') as mock_run: - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + client.mount("/tmp/test-mount", foreground=True) - with patch('jumpstarter_driver_ssh_mount.client.sys') as mock_sys: - mock_sys.platform = "darwin" - client.umount("/tmp/test-mount", lazy=True) + mock_proc.terminate.assert_called_once() - call_args = mock_run.call_args[0][0] - assert "-f" in call_args - assert "-l" not in call_args +# --------------------------------------------------------------------------- +# Extra args and port tests +# --------------------------------------------------------------------------- -def test_extra_args_prefixed_with_dash_o(): +def test_extra_args_prefixed_with_dash_o(mount_instance, mock_portforward_22): """Test that extra_args are correctly prefixed with -o in sshfs command""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) - - with serve(instance) as client: + with serve(mount_instance) as client: mock_proc = MagicMock() mock_proc.poll.return_value = 0 mock_proc.stderr = None - with patch.object(client, '_find_executable', return_value="/usr/bin/sshfs"): - with patch('subprocess.run') as mock_run: - with patch('subprocess.Popen', return_value=mock_proc): - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") - mock_proc.wait.side_effect = [None] - - with patch('os.makedirs'): - with patch('jumpstarter_driver_ssh_mount.client.TcpPortforwardAdapter') as mock_adapter: - mock_adapter.return_value.__enter__ = MagicMock(return_value=("127.0.0.1", 22)) - mock_adapter.return_value.__exit__ = MagicMock(return_value=None) + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + mock_proc.wait.side_effect = [None] - with pytest.raises(Exception, match="sshfs mount failed"): - client.mount("/tmp/test-mount", extra_args=["reconnect", "cache=yes"]) + with pytest.raises(Exception, match="sshfs mount failed"): + client.mount("/tmp/test-mount", extra_args=["reconnect", "cache=yes"]) - test_run_args = mock_run.call_args_list[0][0][0] - # Each extra arg should be preceded by -o - for extra in ["reconnect", "cache=yes"]: - idx = test_run_args.index(extra) - assert test_run_args[idx - 1] == "-o", \ - f"Extra arg '{extra}' not preceded by '-o'" + test_run_args = mock_run.call_args_list[0][0][0] + for extra in ["reconnect", "cache=yes"]: + idx = test_run_args.index(extra) + assert test_run_args[idx - 1] == "-o" -def test_mount_sshfs_not_mounted_after_startup(): - """Test that mount fails if sshfs starts but mountpoint is not actually mounted""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) - - with serve(instance) as client: +def test_mount_port_22_omits_p_flag(mount_instance, mock_portforward_22): + """Test that port 22 does not add -p flag to sshfs args""" + with serve(mount_instance) as client: mock_proc = MagicMock() - mock_proc.poll.return_value = None # Still running - mock_proc.wait.side_effect = [ - subprocess.TimeoutExpired("sshfs", 1), # Startup check - still running - None, # Cleanup wait after terminate - ] - mock_proc.returncode = 0 + mock_proc.poll.return_value = 0 + mock_proc.stderr = None - with patch.object(client, '_find_executable', return_value="/usr/bin/sshfs"): - with patch('subprocess.run') as mock_run: - with patch('subprocess.Popen', return_value=mock_proc): - mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + mock_proc.wait.side_effect = [None] - with patch('os.makedirs'): - with patch('os.path.ismount', return_value=False): - with patch('jumpstarter_driver_ssh_mount.client.TcpPortforwardAdapter') as mock_adapter: - mock_adapter.return_value.__enter__ = MagicMock(return_value=("127.0.0.1", 22)) - mock_adapter.return_value.__exit__ = MagicMock(return_value=None) + with pytest.raises(Exception, match="sshfs mount failed"): + client.mount("/tmp/test-mount") - with pytest.raises(Exception, match="is not mounted"): - client.mount("/tmp/test-mount", foreground=True) + test_run_args = mock_run.call_args_list[0][0][0] + assert "-p" not in test_run_args - # sshfs should have been terminated - mock_proc.terminate.assert_called() +# --------------------------------------------------------------------------- +# Subshell tests +# --------------------------------------------------------------------------- -def test_subshell_bad_shell_raises_click_exception(): +def test_subshell_bad_shell_raises_click_exception(mount_instance): """Test that _run_subshell raises ClickException when shell binary is not found""" - instance = SSHMount( - children={"ssh": _make_ssh_child()}, - ) - - with serve(instance) as client: + with serve(mount_instance) as client: with patch.dict(os.environ, {"SHELL": "/nonexistent/shell"}): with patch('subprocess.run', side_effect=FileNotFoundError("No such file")): with pytest.raises(Exception, match="Shell .* not found"): diff --git a/python/packages/jumpstarter-driver-ssh-mount/pyproject.toml b/python/packages/jumpstarter-driver-ssh-mount/pyproject.toml index c2264bfde..f8e3b0bd2 100644 --- a/python/packages/jumpstarter-driver-ssh-mount/pyproject.toml +++ b/python/packages/jumpstarter-driver-ssh-mount/pyproject.toml @@ -11,7 +11,6 @@ requires-python = ">=3.11" dependencies = [ "click>=8.0.0", "jumpstarter", - "jumpstarter-driver-composite", "jumpstarter-driver-network", "jumpstarter-driver-ssh", ] diff --git a/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/_ssh_utils.py b/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/_ssh_utils.py new file mode 100644 index 000000000..96fd3739f --- /dev/null +++ b/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/_ssh_utils.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import os +import tempfile + + +def create_temp_identity_file(ssh_identity: str, logger) -> str | None: + if not ssh_identity: + return None + + fd = None + temp_path = None + try: + fd, temp_path = tempfile.mkstemp(suffix="_ssh_key") + os.write(fd, ssh_identity.encode()) + os.close(fd) + fd = None + logger.debug("Created temporary identity file: %s", temp_path) + return temp_path + except Exception as e: + logger.error("Failed to create temporary identity file: %s", e) + if fd is not None: + try: + os.close(fd) + except Exception: + pass + if temp_path: + try: + os.unlink(temp_path) + except Exception: + pass + raise + + +def cleanup_identity_file(identity_file: str | None, logger) -> None: + if identity_file: + try: + os.unlink(identity_file) + logger.debug("Cleaned up temporary identity file: %s", identity_file) + except Exception as e: + logger.warning("Failed to clean up identity file %s: %s", identity_file, e) diff --git a/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py b/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py index 5574dcc1a..e47ef92a2 100644 --- a/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py +++ b/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py @@ -1,7 +1,5 @@ -import os import shlex import subprocess -import tempfile from contextlib import asynccontextmanager from dataclasses import dataclass from urllib.parse import urlparse @@ -10,6 +8,7 @@ from jumpstarter_driver_composite.client import CompositeClient from jumpstarter_driver_network.adapters import TcpPortforwardAdapter +from ._ssh_utils import cleanup_identity_file, create_temp_identity_file from jumpstarter.client.core import DriverMethodNotImplemented from jumpstarter.client.decorators import driver_click_command @@ -151,27 +150,7 @@ def run(self, options: SSHCommandRunOptions, args) -> SSHCommandRunResult: def _run_ssh_local(self, host, port, options, args): """Run SSH command with the given host, port, and arguments""" - # Create temporary identity file if needed - ssh_identity = self.identity - identity_file = None - temp_file = None - if ssh_identity: - try: - temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_ssh_key') - temp_file.write(ssh_identity) - temp_file.close() - # Set proper permissions (600) for SSH key - os.chmod(temp_file.name, 0o600) - identity_file = temp_file.name - self.logger.debug("Created temporary identity file: %s", identity_file) - except Exception as e: - self.logger.error("Failed to create temporary identity file: %s", e) - if temp_file: - try: - os.unlink(temp_file.name) - except Exception: - pass - raise + identity_file = create_temp_identity_file(self.identity, self.logger) try: # Build SSH command arguments @@ -186,13 +165,7 @@ def _run_ssh_local(self, host, port, options, args): # Execute the command return self._execute_ssh_command(ssh_args, options) finally: - # Clean up temporary identity file - if identity_file: - try: - os.unlink(identity_file) - self.logger.debug("Cleaned up temporary identity file: %s", identity_file) - except Exception as e: - self.logger.warning("Failed to clean up temporary identity file %s: %s", identity_file, str(e)) + cleanup_identity_file(identity_file, self.logger) def _build_ssh_command_args(self, port, identity_file, args): """Build initial SSH command arguments""" diff --git a/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py b/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py index 92a540406..64b05ac5b 100644 --- a/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py +++ b/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py @@ -585,6 +585,9 @@ def test_ssh_command_without_identity(): assert result.stdout == "some stdout" +_UTILS = "jumpstarter_driver_ssh._ssh_utils" + + def test_ssh_identity_temp_file_creation_and_cleanup(): """Test that temporary identity file is created and cleaned up properly""" instance = SSHWrapper( @@ -597,33 +600,23 @@ def test_ssh_identity_temp_file_creation_and_cleanup(): with patch('subprocess.run') as mock_run: mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="") - with patch('tempfile.NamedTemporaryFile') as mock_temp_file: - with patch('os.chmod') as mock_chmod: - with patch('os.unlink') as mock_unlink: - # Mock the temporary file - mock_temp_file_instance = MagicMock() - mock_temp_file_instance.name = "/tmp/test_ssh_key_12345" - mock_temp_file_instance.write = MagicMock() - mock_temp_file_instance.close = MagicMock() - mock_temp_file.return_value = mock_temp_file_instance - - # Test SSH command with identity - result = client.run(SSHCommandRunOptions(direct=False), ["hostname"]) - assert isinstance(result, SSHCommandRunResult) - - # Verify temporary file was created - mock_temp_file.assert_called_once_with(mode='w', delete=False, suffix='_ssh_key') - mock_temp_file_instance.write.assert_called_once_with(TEST_SSH_KEY) - mock_temp_file_instance.close.assert_called_once() - - # Verify proper permissions were set - mock_chmod.assert_called_once_with("/tmp/test_ssh_key_12345", 0o600) + mkstemp_rv = (5, "/tmp/test_ssh_key_12345") + with ( + patch(f"{_UTILS}.tempfile.mkstemp", return_value=mkstemp_rv) as mock_mkstemp, + patch(f"{_UTILS}.os.write") as mock_write, + patch(f"{_UTILS}.os.close") as mock_close, + patch(f"{_UTILS}.os.unlink") as mock_unlink, + ): + result = client.run(SSHCommandRunOptions(direct=False), ["hostname"]) + assert isinstance(result, SSHCommandRunResult) - # Verify temporary file was cleaned up - mock_unlink.assert_called_once_with("/tmp/test_ssh_key_12345") + mock_mkstemp.assert_called_once_with(suffix="_ssh_key") + mock_write.assert_called_once_with(5, TEST_SSH_KEY.encode()) + mock_close.assert_called_once_with(5) + mock_unlink.assert_called_once_with("/tmp/test_ssh_key_12345") - assert result.return_code == 0 - assert result.stdout == "some stdout" + assert result.return_code == 0 + assert result.stdout == "some stdout" def test_ssh_identity_temp_file_creation_error(): @@ -638,16 +631,46 @@ def test_ssh_identity_temp_file_creation_error(): with patch('subprocess.run') as mock_run: mock_run.return_value = MagicMock(returncode=0) - with patch('tempfile.NamedTemporaryFile') as mock_temp_file: - mock_temp_file.side_effect = OSError("Permission denied") + with patch(f"{_UTILS}.tempfile.mkstemp") as mock_mkstemp: + mock_mkstemp.side_effect = OSError("Permission denied") + + with pytest.raises(ExceptionGroup) as exc_info: + client.run(SSHCommandRunOptions(direct=False), ["hostname"]) + + assert any( + isinstance(e, OSError) and "Permission denied" in str(e) + for e in exc_info.value.exceptions + ) + + +def test_ssh_identity_temp_file_creation_error_fd_cleanup(): + """Test that fd is closed when write fails after mkstemp succeeds""" + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity=TEST_SSH_KEY + ) + + with serve(instance) as client: + with patch('subprocess.run') as mock_run: + mock_run.return_value = MagicMock(returncode=0) - # Test SSH command with identity should raise an error - # The exception will be wrapped in an ExceptionGroup due to the context manager + mkstemp_rv = (5, "/tmp/test_ssh_key_12345") + with ( + patch(f"{_UTILS}.tempfile.mkstemp", return_value=mkstemp_rv), + patch(f"{_UTILS}.os.write", side_effect=OSError("Disk full")), + patch(f"{_UTILS}.os.close") as mock_close, + patch(f"{_UTILS}.os.unlink") as mock_unlink, + ): with pytest.raises(ExceptionGroup) as exc_info: client.run(SSHCommandRunOptions(direct=False), ["hostname"]) - # Check that the original OSError is in the exception group - assert any(isinstance(e, OSError) and "Permission denied" in str(e) for e in exc_info.value.exceptions) + assert any( + isinstance(e, OSError) and "Disk full" in str(e) + for e in exc_info.value.exceptions + ) + mock_close.assert_called_once_with(5) + mock_unlink.assert_called_once_with("/tmp/test_ssh_key_12345") def test_ssh_identity_temp_file_cleanup_error(): @@ -662,36 +685,24 @@ def test_ssh_identity_temp_file_cleanup_error(): with patch('subprocess.run') as mock_run: mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="") - with patch('tempfile.NamedTemporaryFile') as mock_temp_file: - with patch('os.chmod') as mock_chmod: - with patch('os.unlink') as mock_unlink: - # Mock the temporary file - mock_temp_file_instance = MagicMock() - mock_temp_file_instance.name = "/tmp/test_ssh_key_12345" - mock_temp_file_instance.write = MagicMock() - mock_temp_file_instance.close = MagicMock() - mock_temp_file.return_value = mock_temp_file_instance - - # Mock cleanup failure - mock_unlink.side_effect = OSError("Permission denied") - - # Test SSH command with identity - should still succeed but log warning - with patch.object(client, 'logger') as mock_logger: - result = client.run(SSHCommandRunOptions(direct=False), ["hostname"]) - assert isinstance(result, SSHCommandRunResult) - - # Verify chmod was called - mock_chmod.assert_called_once_with("/tmp/test_ssh_key_12345", 0o600) - - # Verify warning was logged - mock_logger.warning.assert_called_once_with( - "Failed to clean up temporary identity file %s: %s", - "/tmp/test_ssh_key_12345", - str(mock_unlink.side_effect) - ) - - assert result.return_code == 0 - assert result.stdout == "some stdout" + mkstemp_rv = (5, "/tmp/test_ssh_key_12345") + with ( + patch(f"{_UTILS}.tempfile.mkstemp", return_value=mkstemp_rv), + patch(f"{_UTILS}.os.write"), + patch(f"{_UTILS}.os.close"), + patch(f"{_UTILS}.os.unlink", side_effect=OSError("Permission denied")), + ): + with patch.object(client, 'logger') as mock_logger: + result = client.run(SSHCommandRunOptions(direct=False), ["hostname"]) + assert isinstance(result, SSHCommandRunResult) + + mock_logger.warning.assert_called_once() + warning_args = mock_logger.warning.call_args[0] + assert "Failed to clean up identity file" in warning_args[0] + assert "/tmp/test_ssh_key_12345" in warning_args[1] + + assert result.return_code == 0 + assert result.stdout == "some stdout" def test_ssh_client_properties(): diff --git a/python/uv.lock b/python/uv.lock index 234253413..d8a8f317f 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.11" resolution-markers = [ "python_full_version >= '3.14'", @@ -33,6 +33,7 @@ members = [ "jumpstarter-driver-iscsi", "jumpstarter-driver-mitmproxy", "jumpstarter-driver-network", + "jumpstarter-driver-noyito-relay", "jumpstarter-driver-opendal", "jumpstarter-driver-pi-pico", "jumpstarter-driver-power", @@ -46,6 +47,7 @@ members = [ "jumpstarter-driver-someip", "jumpstarter-driver-ssh", "jumpstarter-driver-ssh-mitm", + "jumpstarter-driver-ssh-mount", "jumpstarter-driver-tasmota", "jumpstarter-driver-tftp", "jumpstarter-driver-tmt", @@ -1680,6 +1682,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/08/e7/ae38d7a6dfba0533684e0b2136817d667588ae3ec984c1a4e5df5eb88482/hatchling-1.27.0-py3-none-any.whl", hash = "sha256:d3a2f3567c4f926ea39849cdf924c7e99e6686c9c8e288ae1037c8fa2a5d937b", size = 75794, upload-time = "2024-12-15T17:08:10.364Z" }, ] +[[package]] +name = "hid" +version = "1.0.9" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e9/f8/0357a8aa8874a243e96d08a8568efaf7478293e1a3441ddca18039b690c1/hid-1.0.9.tar.gz", hash = "sha256:f4471f11f0e176d1b0cb1b243e55498cc90347a3aede735655304395694ac182", size = 4973, upload-time = "2026-02-05T15:35:20.595Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b8/c7/f0e1ad95179f44a6fc7a9140be025812cc7a62cf7390442b685a57ee1417/hid-1.0.9-py3-none-any.whl", hash = "sha256:6b9289e00bbc1e1589bec0c7f376a63fe03a4a4a1875575d0ad60e3e11a349f4", size = 4959, upload-time = "2026-02-05T15:35:19.269Z" }, +] + [[package]] name = "hpack" version = "4.1.0" @@ -2263,12 +2274,15 @@ source = { editable = "packages/jumpstarter-driver-ble" } dependencies = [ { name = "anyio" }, { name = "bleak" }, + { name = "click" }, { name = "jumpstarter" }, + { name = "jumpstarter-driver-network" }, ] [package.dev-dependencies] dev = [ { name = "pytest" }, + { name = "pytest-anyio" }, { name = "pytest-cov" }, ] @@ -2276,12 +2290,15 @@ dev = [ requires-dist = [ { name = "anyio", specifier = ">=4.10.0" }, { name = "bleak", specifier = ">=1.1.1" }, + { name = "click", specifier = ">=8.1.8" }, { name = "jumpstarter", editable = "packages/jumpstarter" }, + { name = "jumpstarter-driver-network", editable = "packages/jumpstarter-driver-network" }, ] [package.metadata.requires-dev] dev = [ { name = "pytest", specifier = ">=8.3.3" }, + { name = "pytest-anyio", specifier = ">=0.0.0" }, { name = "pytest-cov", specifier = ">=6.0.0" }, ] @@ -2753,6 +2770,38 @@ dev = [ { name = "websocket-client", specifier = ">=1.8.0" }, ] +[[package]] +name = "jumpstarter-driver-noyito-relay" +source = { editable = "packages/jumpstarter-driver-noyito-relay" } +dependencies = [ + { name = "hid" }, + { name = "jumpstarter" }, + { name = "jumpstarter-driver-power" }, + { name = "pyserial" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pytest" }, + { name = "pytest-cov" }, + { name = "pytest-mock" }, +] + +[package.metadata] +requires-dist = [ + { name = "hid", specifier = ">=1.0.4" }, + { name = "jumpstarter", editable = "packages/jumpstarter" }, + { name = "jumpstarter-driver-power", editable = "packages/jumpstarter-driver-power" }, + { name = "pyserial", specifier = ">=3.5" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pytest", specifier = ">=8.3.3" }, + { name = "pytest-cov", specifier = ">=6.0.0" }, + { name = "pytest-mock", specifier = ">=3.14.0" }, +] + [[package]] name = "jumpstarter-driver-opendal" source = { editable = "packages/jumpstarter-driver-opendal" } @@ -3094,7 +3143,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "jumpstarter", editable = "packages/jumpstarter" }, - { name = "opensomeip", specifier = ">=0.1.2" }, + { name = "opensomeip", specifier = ">=0.1.2,<0.2.0" }, ] [package.metadata.requires-dev] @@ -3167,6 +3216,36 @@ dev = [ { name = "trio", specifier = ">=0.28.0" }, ] +[[package]] +name = "jumpstarter-driver-ssh-mount" +source = { editable = "packages/jumpstarter-driver-ssh-mount" } +dependencies = [ + { name = "click" }, + { name = "jumpstarter" }, + { name = "jumpstarter-driver-network" }, + { name = "jumpstarter-driver-ssh" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pytest" }, + { name = "pytest-cov" }, +] + +[package.metadata] +requires-dist = [ + { name = "click", specifier = ">=8.0.0" }, + { name = "jumpstarter", editable = "packages/jumpstarter" }, + { name = "jumpstarter-driver-network", editable = "packages/jumpstarter-driver-network" }, + { name = "jumpstarter-driver-ssh", editable = "packages/jumpstarter-driver-ssh" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pytest", specifier = ">=8.3.3" }, + { name = "pytest-cov", specifier = ">=6.0.0" }, +] + [[package]] name = "jumpstarter-driver-tasmota" source = { editable = "packages/jumpstarter-driver-tasmota" } @@ -5257,6 +5336,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0d/d2/dfc2f25f3905921c2743c300a48d9494d29032f1389fc142e718d6978fb2/pytest_httpserver-1.1.3-py3-none-any.whl", hash = "sha256:5f84757810233e19e2bb5287f3826a71c97a3740abe3a363af9155c0f82fdbb9", size = 21000, upload-time = "2025-04-10T08:17:13.906Z" }, ] +[[package]] +name = "pytest-mock" +version = "3.15.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/68/14/eb014d26be205d38ad5ad20d9a80f7d201472e08167f0bb4361e251084a9/pytest_mock-3.15.1.tar.gz", hash = "sha256:1849a238f6f396da19762269de72cb1814ab44416fa73a8686deac10b0d87a0f", size = 34036, upload-time = "2025-09-16T16:37:27.081Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" }, +] + [[package]] name = "pytest-mqtt" version = "0.5.0"