|
10 | 10 | from jumpstarter.common.exceptions import ConfigurationError |
11 | 11 | from jumpstarter.common.utils import serve |
12 | 12 |
|
| 13 | +# Test SSH key content used in multiple tests |
| 14 | +TEST_SSH_KEY = ( |
| 15 | + "-----BEGIN OPENSSH PRIVATE KEY-----\n" |
| 16 | + "test-key-content\n" |
| 17 | + "-----END OPENSSH PRIVATE KEY-----" |
| 18 | +) |
| 19 | + |
13 | 20 |
|
14 | 21 | def test_ssh_wrapper_defaults(): |
15 | 22 | """Test SSH wrapper with default configuration""" |
@@ -348,3 +355,295 @@ def test_ssh_command_with_command_l_flag_does_not_interfere_with_username_inject |
348 | 355 | assert ssh_l_index < hostname_index < command_l_index |
349 | 356 |
|
350 | 357 | assert result == 0 |
| 358 | + |
| 359 | + |
| 360 | +def test_ssh_identity_string_configuration(): |
| 361 | + """Test SSH wrapper with ssh_identity string configuration""" |
| 362 | + instance = SSHWrapper( |
| 363 | + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, |
| 364 | + default_username="testuser", |
| 365 | + ssh_identity=TEST_SSH_KEY |
| 366 | + ) |
| 367 | + |
| 368 | + # Test that the instance was created correctly |
| 369 | + assert instance.ssh_identity == TEST_SSH_KEY |
| 370 | + assert instance.ssh_identity_file is None |
| 371 | + |
| 372 | + # Test that the client class is correct |
| 373 | + assert instance.client() == "jumpstarter_driver_ssh.client.SSHWrapperClient" |
| 374 | + |
| 375 | + |
| 376 | +def test_ssh_identity_file_configuration(): |
| 377 | + """Test SSH wrapper with ssh_identity_file configuration""" |
| 378 | + import os |
| 379 | + import tempfile |
| 380 | + |
| 381 | + # Create a temporary file with SSH key content |
| 382 | + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_test_key') as temp_file: |
| 383 | + temp_file.write(TEST_SSH_KEY) |
| 384 | + temp_file_path = temp_file.name |
| 385 | + |
| 386 | + try: |
| 387 | + instance = SSHWrapper( |
| 388 | + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, |
| 389 | + default_username="testuser", |
| 390 | + ssh_identity_file=temp_file_path |
| 391 | + ) |
| 392 | + |
| 393 | + # Test that the instance was created correctly |
| 394 | + assert instance.ssh_identity == TEST_SSH_KEY |
| 395 | + assert instance.ssh_identity_file == temp_file_path |
| 396 | + |
| 397 | + # Test that the client class is correct |
| 398 | + assert instance.client() == "jumpstarter_driver_ssh.client.SSHWrapperClient" |
| 399 | + finally: |
| 400 | + # Clean up the temporary file |
| 401 | + os.unlink(temp_file_path) |
| 402 | + |
| 403 | + |
| 404 | +def test_ssh_identity_validation_error(): |
| 405 | + """Test SSH wrapper raises error when both ssh_identity and ssh_identity_file are provided""" |
| 406 | + with pytest.raises(ConfigurationError, match="Cannot specify both ssh_identity and ssh_identity_file"): |
| 407 | + SSHWrapper( |
| 408 | + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, |
| 409 | + default_username="testuser", |
| 410 | + ssh_identity="test-key-content", |
| 411 | + ssh_identity_file="/path/to/key" |
| 412 | + ) |
| 413 | + |
| 414 | + |
| 415 | +def test_ssh_identity_file_read_error(): |
| 416 | + """Test SSH wrapper raises error when ssh_identity_file cannot be read""" |
| 417 | + with pytest.raises(ConfigurationError, match="Failed to read ssh_identity_file"): |
| 418 | + SSHWrapper( |
| 419 | + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, |
| 420 | + default_username="testuser", |
| 421 | + ssh_identity_file="/nonexistent/path/to/key" |
| 422 | + ) |
| 423 | + |
| 424 | + |
| 425 | +def test_ssh_command_with_identity_string(): |
| 426 | + """Test SSH command execution with ssh_identity string""" |
| 427 | + instance = SSHWrapper( |
| 428 | + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, |
| 429 | + default_username="testuser", |
| 430 | + ssh_identity=TEST_SSH_KEY |
| 431 | + ) |
| 432 | + |
| 433 | + with serve(instance) as client: |
| 434 | + with patch('subprocess.run') as mock_run: |
| 435 | + mock_run.return_value = MagicMock(returncode=0) |
| 436 | + |
| 437 | + # Test SSH command with identity string |
| 438 | + result = client.run(False, ["hostname"]) |
| 439 | + |
| 440 | + # Verify subprocess.run was called |
| 441 | + assert mock_run.called |
| 442 | + call_args = mock_run.call_args[0][0] # First positional argument |
| 443 | + |
| 444 | + # Should include -i flag with temporary identity file |
| 445 | + assert "-i" in call_args |
| 446 | + identity_file_index = call_args.index("-i") |
| 447 | + identity_file_path = call_args[identity_file_index + 1] |
| 448 | + |
| 449 | + # The identity file should be a temporary file |
| 450 | + assert identity_file_path.endswith("_ssh_key") |
| 451 | + assert "/tmp" in identity_file_path or "/var/tmp" in identity_file_path |
| 452 | + |
| 453 | + # Should include -l testuser |
| 454 | + assert "-l" in call_args |
| 455 | + assert "testuser" in call_args |
| 456 | + |
| 457 | + # Should include the actual hostname (127.0.0.1) at the end |
| 458 | + assert "127.0.0.1" in call_args |
| 459 | + assert "hostname" in call_args |
| 460 | + |
| 461 | + assert result == 0 |
| 462 | + |
| 463 | + |
| 464 | +def test_ssh_command_with_identity_file(): |
| 465 | + """Test SSH command execution with ssh_identity_file""" |
| 466 | + import os |
| 467 | + import tempfile |
| 468 | + |
| 469 | + # Create a temporary file with SSH key content |
| 470 | + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_test_key') as temp_file: |
| 471 | + temp_file.write(TEST_SSH_KEY) |
| 472 | + temp_file_path = temp_file.name |
| 473 | + |
| 474 | + try: |
| 475 | + instance = SSHWrapper( |
| 476 | + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, |
| 477 | + default_username="testuser", |
| 478 | + ssh_identity_file=temp_file_path |
| 479 | + ) |
| 480 | + |
| 481 | + with serve(instance) as client: |
| 482 | + with patch('subprocess.run') as mock_run: |
| 483 | + mock_run.return_value = MagicMock(returncode=0) |
| 484 | + |
| 485 | + # Test SSH command with identity file |
| 486 | + result = client.run(False, ["hostname"]) |
| 487 | + |
| 488 | + # Verify subprocess.run was called |
| 489 | + assert mock_run.called |
| 490 | + call_args = mock_run.call_args[0][0] # First positional argument |
| 491 | + |
| 492 | + # Should include -i flag with temporary identity file |
| 493 | + assert "-i" in call_args |
| 494 | + identity_file_index = call_args.index("-i") |
| 495 | + identity_file_path = call_args[identity_file_index + 1] |
| 496 | + |
| 497 | + # The identity file should be a temporary file (not the original file) |
| 498 | + assert identity_file_path.endswith("_ssh_key") |
| 499 | + assert "/tmp" in identity_file_path or "/var/tmp" in identity_file_path |
| 500 | + assert identity_file_path != temp_file_path |
| 501 | + |
| 502 | + # Should include -l testuser |
| 503 | + assert "-l" in call_args |
| 504 | + assert "testuser" in call_args |
| 505 | + |
| 506 | + # Should include the actual hostname (127.0.0.1) at the end |
| 507 | + assert "127.0.0.1" in call_args |
| 508 | + assert "hostname" in call_args |
| 509 | + |
| 510 | + assert result == 0 |
| 511 | + finally: |
| 512 | + # Clean up the temporary file |
| 513 | + os.unlink(temp_file_path) |
| 514 | + |
| 515 | + |
| 516 | +def test_ssh_command_without_identity(): |
| 517 | + """Test SSH command execution without identity (should not include -i flag)""" |
| 518 | + instance = SSHWrapper( |
| 519 | + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, |
| 520 | + default_username="testuser" |
| 521 | + ) |
| 522 | + |
| 523 | + with serve(instance) as client: |
| 524 | + with patch('subprocess.run') as mock_run: |
| 525 | + mock_run.return_value = MagicMock(returncode=0) |
| 526 | + |
| 527 | + # Test SSH command without identity |
| 528 | + result = client.run(False, ["hostname"]) |
| 529 | + |
| 530 | + # Verify subprocess.run was called |
| 531 | + assert mock_run.called |
| 532 | + call_args = mock_run.call_args[0][0] # First positional argument |
| 533 | + |
| 534 | + # Should NOT include -i flag |
| 535 | + assert "-i" not in call_args |
| 536 | + |
| 537 | + # Should include -l testuser |
| 538 | + assert "-l" in call_args |
| 539 | + assert "testuser" in call_args |
| 540 | + |
| 541 | + # Should include the actual hostname (127.0.0.1) at the end |
| 542 | + assert "127.0.0.1" in call_args |
| 543 | + assert "hostname" in call_args |
| 544 | + |
| 545 | + assert result == 0 |
| 546 | + |
| 547 | + |
| 548 | +def test_ssh_identity_temp_file_creation_and_cleanup(): |
| 549 | + """Test that temporary identity file is created and cleaned up properly""" |
| 550 | + instance = SSHWrapper( |
| 551 | + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, |
| 552 | + default_username="testuser", |
| 553 | + ssh_identity=TEST_SSH_KEY |
| 554 | + ) |
| 555 | + |
| 556 | + with serve(instance) as client: |
| 557 | + with patch('subprocess.run') as mock_run: |
| 558 | + mock_run.return_value = MagicMock(returncode=0) |
| 559 | + |
| 560 | + with patch('tempfile.NamedTemporaryFile') as mock_temp_file: |
| 561 | + with patch('os.chmod') as mock_chmod: |
| 562 | + with patch('os.unlink') as mock_unlink: |
| 563 | + # Mock the temporary file |
| 564 | + mock_temp_file_instance = MagicMock() |
| 565 | + mock_temp_file_instance.name = "/tmp/test_ssh_key_12345" |
| 566 | + mock_temp_file_instance.write = MagicMock() |
| 567 | + mock_temp_file_instance.close = MagicMock() |
| 568 | + mock_temp_file.return_value = mock_temp_file_instance |
| 569 | + |
| 570 | + # Test SSH command with identity |
| 571 | + result = client.run(False, ["hostname"]) |
| 572 | + |
| 573 | + # Verify temporary file was created |
| 574 | + mock_temp_file.assert_called_once_with(mode='wb', delete=False, suffix='_ssh_key') |
| 575 | + mock_temp_file_instance.write.assert_called_once_with(TEST_SSH_KEY.encode('utf-8')) |
| 576 | + mock_temp_file_instance.close.assert_called_once() |
| 577 | + |
| 578 | + # Verify proper permissions were set |
| 579 | + mock_chmod.assert_called_once_with("/tmp/test_ssh_key_12345", 0o600) |
| 580 | + |
| 581 | + # Verify temporary file was cleaned up |
| 582 | + mock_unlink.assert_called_once_with("/tmp/test_ssh_key_12345") |
| 583 | + |
| 584 | + assert result == 0 |
| 585 | + |
| 586 | + |
| 587 | +def test_ssh_identity_temp_file_creation_error(): |
| 588 | + """Test error handling when temporary identity file creation fails""" |
| 589 | + instance = SSHWrapper( |
| 590 | + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, |
| 591 | + default_username="testuser", |
| 592 | + ssh_identity=TEST_SSH_KEY |
| 593 | + ) |
| 594 | + |
| 595 | + with serve(instance) as client: |
| 596 | + with patch('subprocess.run') as mock_run: |
| 597 | + mock_run.return_value = MagicMock(returncode=0) |
| 598 | + |
| 599 | + with patch('tempfile.NamedTemporaryFile') as mock_temp_file: |
| 600 | + mock_temp_file.side_effect = OSError("Permission denied") |
| 601 | + |
| 602 | + # Test SSH command with identity should raise an error |
| 603 | + # The exception will be wrapped in an ExceptionGroup due to the context manager |
| 604 | + with pytest.raises(ExceptionGroup) as exc_info: |
| 605 | + client.run(False, ["hostname"]) |
| 606 | + |
| 607 | + # Check that the original OSError is in the exception group |
| 608 | + assert any(isinstance(e, OSError) and "Permission denied" in str(e) for e in exc_info.value.exceptions) |
| 609 | + |
| 610 | + |
| 611 | +def test_ssh_identity_temp_file_cleanup_error(): |
| 612 | + """Test error handling when temporary identity file cleanup fails""" |
| 613 | + instance = SSHWrapper( |
| 614 | + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, |
| 615 | + default_username="testuser", |
| 616 | + ssh_identity=TEST_SSH_KEY |
| 617 | + ) |
| 618 | + |
| 619 | + with serve(instance) as client: |
| 620 | + with patch('subprocess.run') as mock_run: |
| 621 | + mock_run.return_value = MagicMock(returncode=0) |
| 622 | + |
| 623 | + with patch('tempfile.NamedTemporaryFile') as mock_temp_file: |
| 624 | + with patch('os.chmod') as mock_chmod: |
| 625 | + with patch('os.unlink') as mock_unlink: |
| 626 | + # Mock the temporary file |
| 627 | + mock_temp_file_instance = MagicMock() |
| 628 | + mock_temp_file_instance.name = "/tmp/test_ssh_key_12345" |
| 629 | + mock_temp_file_instance.write = MagicMock() |
| 630 | + mock_temp_file_instance.close = MagicMock() |
| 631 | + mock_temp_file.return_value = mock_temp_file_instance |
| 632 | + |
| 633 | + # Mock cleanup failure |
| 634 | + mock_unlink.side_effect = OSError("Permission denied") |
| 635 | + |
| 636 | + # Test SSH command with identity - should still succeed but log warning |
| 637 | + with patch.object(client, 'logger') as mock_logger: |
| 638 | + result = client.run(False, ["hostname"]) |
| 639 | + |
| 640 | + # Verify chmod was called |
| 641 | + mock_chmod.assert_called_once_with("/tmp/test_ssh_key_12345", 0o600) |
| 642 | + |
| 643 | + # Verify warning was logged |
| 644 | + mock_logger.warning.assert_called_once() |
| 645 | + warning_call = mock_logger.warning.call_args[0][0] |
| 646 | + assert "Failed to clean up temporary identity file" in warning_call |
| 647 | + assert "/tmp/test_ssh_key_12345" in warning_call |
| 648 | + |
| 649 | + assert result == 0 |
0 commit comments