From 6fccabc40cfe34ba3510f644f38f27a1abdf94e6 Mon Sep 17 00:00:00 2001 From: Miguel Angel Ajo Pelayo Date: Fri, 3 Oct 2025 14:01:37 +0000 Subject: [PATCH] Test the jumpstarter-driver-ssh identity key injection --- .../jumpstarter_driver_ssh/driver_test.py | 299 ++++++++++++++++++ 1 file changed, 299 insertions(+) diff --git a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py index 4501828c9..0533c5a65 100644 --- a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py +++ b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py @@ -10,6 +10,13 @@ from jumpstarter.common.exceptions import ConfigurationError from jumpstarter.common.utils import serve +# Test SSH key content used in multiple tests +TEST_SSH_KEY = ( + "-----BEGIN OPENSSH PRIVATE KEY-----\n" + "test-key-content\n" + "-----END OPENSSH PRIVATE KEY-----" +) + def test_ssh_wrapper_defaults(): """Test SSH wrapper with default configuration""" @@ -348,3 +355,295 @@ def test_ssh_command_with_command_l_flag_does_not_interfere_with_username_inject assert ssh_l_index < hostname_index < command_l_index assert result == 0 + + +def test_ssh_identity_string_configuration(): + """Test SSH wrapper with ssh_identity string configuration""" + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity=TEST_SSH_KEY + ) + + # Test that the instance was created correctly + assert instance.ssh_identity == TEST_SSH_KEY + assert instance.ssh_identity_file is None + + # Test that the client class is correct + assert instance.client() == "jumpstarter_driver_ssh.client.SSHWrapperClient" + + +def test_ssh_identity_file_configuration(): + """Test SSH wrapper with ssh_identity_file configuration""" + import os + import tempfile + + # Create a temporary file with SSH key content + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_test_key') as temp_file: + temp_file.write(TEST_SSH_KEY) + temp_file_path = temp_file.name + + try: + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity_file=temp_file_path + ) + + # Test that the instance was created correctly + assert instance.ssh_identity == TEST_SSH_KEY + assert instance.ssh_identity_file == temp_file_path + + # Test that the client class is correct + assert instance.client() == "jumpstarter_driver_ssh.client.SSHWrapperClient" + finally: + # Clean up the temporary file + os.unlink(temp_file_path) + + +def test_ssh_identity_validation_error(): + """Test SSH wrapper raises error when both ssh_identity and ssh_identity_file are provided""" + with pytest.raises(ConfigurationError, match="Cannot specify both ssh_identity and ssh_identity_file"): + SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity="test-key-content", + ssh_identity_file="/path/to/key" + ) + + +def test_ssh_identity_file_read_error(): + """Test SSH wrapper raises error when ssh_identity_file cannot be read""" + with pytest.raises(ConfigurationError, match="Failed to read ssh_identity_file"): + SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity_file="/nonexistent/path/to/key" + ) + + +def test_ssh_command_with_identity_string(): + """Test SSH command execution with ssh_identity string""" + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity=TEST_SSH_KEY + ) + + with serve(instance) as client: + with patch('subprocess.run') as mock_run: + mock_run.return_value = MagicMock(returncode=0) + + # Test SSH command with identity string + result = client.run(False, ["hostname"]) + + # Verify subprocess.run was called + assert mock_run.called + call_args = mock_run.call_args[0][0] # First positional argument + + # Should include -i flag with temporary identity file + assert "-i" in call_args + identity_file_index = call_args.index("-i") + identity_file_path = call_args[identity_file_index + 1] + + # The identity file should be a temporary file + assert identity_file_path.endswith("_ssh_key") + assert "/tmp" in identity_file_path or "/var/tmp" in identity_file_path + + # Should include -l testuser + assert "-l" in call_args + assert "testuser" in call_args + + # Should include the actual hostname (127.0.0.1) at the end + assert "127.0.0.1" in call_args + assert "hostname" in call_args + + assert result == 0 + + +def test_ssh_command_with_identity_file(): + """Test SSH command execution with ssh_identity_file""" + import os + import tempfile + + # Create a temporary file with SSH key content + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_test_key') as temp_file: + temp_file.write(TEST_SSH_KEY) + temp_file_path = temp_file.name + + try: + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity_file=temp_file_path + ) + + with serve(instance) as client: + with patch('subprocess.run') as mock_run: + mock_run.return_value = MagicMock(returncode=0) + + # Test SSH command with identity file + result = client.run(False, ["hostname"]) + + # Verify subprocess.run was called + assert mock_run.called + call_args = mock_run.call_args[0][0] # First positional argument + + # Should include -i flag with temporary identity file + assert "-i" in call_args + identity_file_index = call_args.index("-i") + identity_file_path = call_args[identity_file_index + 1] + + # The identity file should be a temporary file (not the original file) + assert identity_file_path.endswith("_ssh_key") + assert "/tmp" in identity_file_path or "/var/tmp" in identity_file_path + assert identity_file_path != temp_file_path + + # Should include -l testuser + assert "-l" in call_args + assert "testuser" in call_args + + # Should include the actual hostname (127.0.0.1) at the end + assert "127.0.0.1" in call_args + assert "hostname" in call_args + + assert result == 0 + finally: + # Clean up the temporary file + os.unlink(temp_file_path) + + +def test_ssh_command_without_identity(): + """Test SSH command execution without identity (should not include -i flag)""" + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser" + ) + + with serve(instance) as client: + with patch('subprocess.run') as mock_run: + mock_run.return_value = MagicMock(returncode=0) + + # Test SSH command without identity + result = client.run(False, ["hostname"]) + + # Verify subprocess.run was called + assert mock_run.called + call_args = mock_run.call_args[0][0] # First positional argument + + # Should NOT include -i flag + assert "-i" not in call_args + + # Should include -l testuser + assert "-l" in call_args + assert "testuser" in call_args + + # Should include the actual hostname (127.0.0.1) at the end + assert "127.0.0.1" in call_args + assert "hostname" in call_args + + assert result == 0 + + +def test_ssh_identity_temp_file_creation_and_cleanup(): + """Test that temporary identity file is created and cleaned up properly""" + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity=TEST_SSH_KEY + ) + + with serve(instance) as client: + with patch('subprocess.run') as mock_run: + mock_run.return_value = MagicMock(returncode=0) + + with patch('tempfile.NamedTemporaryFile') as mock_temp_file: + with patch('os.chmod') as mock_chmod: + with patch('os.unlink') as mock_unlink: + # Mock the temporary file + mock_temp_file_instance = MagicMock() + mock_temp_file_instance.name = "/tmp/test_ssh_key_12345" + mock_temp_file_instance.write = MagicMock() + mock_temp_file_instance.close = MagicMock() + mock_temp_file.return_value = mock_temp_file_instance + + # Test SSH command with identity + result = client.run(False, ["hostname"]) + + # Verify temporary file was created + mock_temp_file.assert_called_once_with(mode='w', delete=False, suffix='_ssh_key') + mock_temp_file_instance.write.assert_called_once_with(TEST_SSH_KEY) + mock_temp_file_instance.close.assert_called_once() + + # Verify proper permissions were set + mock_chmod.assert_called_once_with("/tmp/test_ssh_key_12345", 0o600) + + # Verify temporary file was cleaned up + mock_unlink.assert_called_once_with("/tmp/test_ssh_key_12345") + + assert result == 0 + + +def test_ssh_identity_temp_file_creation_error(): + """Test error handling when temporary identity file creation fails""" + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity=TEST_SSH_KEY + ) + + with serve(instance) as client: + with patch('subprocess.run') as mock_run: + mock_run.return_value = MagicMock(returncode=0) + + with patch('tempfile.NamedTemporaryFile') as mock_temp_file: + mock_temp_file.side_effect = OSError("Permission denied") + + # Test SSH command with identity should raise an error + # The exception will be wrapped in an ExceptionGroup due to the context manager + with pytest.raises(ExceptionGroup) as exc_info: + client.run(False, ["hostname"]) + + # Check that the original OSError is in the exception group + assert any(isinstance(e, OSError) and "Permission denied" in str(e) for e in exc_info.value.exceptions) + + +def test_ssh_identity_temp_file_cleanup_error(): + """Test error handling when temporary identity file cleanup fails""" + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity=TEST_SSH_KEY + ) + + with serve(instance) as client: + with patch('subprocess.run') as mock_run: + mock_run.return_value = MagicMock(returncode=0) + + with patch('tempfile.NamedTemporaryFile') as mock_temp_file: + with patch('os.chmod') as mock_chmod: + with patch('os.unlink') as mock_unlink: + # Mock the temporary file + mock_temp_file_instance = MagicMock() + mock_temp_file_instance.name = "/tmp/test_ssh_key_12345" + mock_temp_file_instance.write = MagicMock() + mock_temp_file_instance.close = MagicMock() + mock_temp_file.return_value = mock_temp_file_instance + + # Mock cleanup failure + mock_unlink.side_effect = OSError("Permission denied") + + # Test SSH command with identity - should still succeed but log warning + with patch.object(client, 'logger') as mock_logger: + result = client.run(False, ["hostname"]) + + # Verify chmod was called + mock_chmod.assert_called_once_with("/tmp/test_ssh_key_12345", 0o600) + + # Verify warning was logged + mock_logger.warning.assert_called_once() + warning_call = mock_logger.warning.call_args[0][0] + assert "Failed to clean up temporary identity file" in warning_call + assert "/tmp/test_ssh_key_12345" in warning_call + + assert result == 0