Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.

Commit 0ed4075

Browse files
authored
Merge pull request #694 from jumpstarter-dev/backport-687-to-release-0.7
[Backport release-0.7] add ssh identity configuration
2 parents d970e5c + 3576c0e commit 0ed4075

2 files changed

Lines changed: 70 additions & 15 deletions

File tree

packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import os
12
import shlex
23
import subprocess
4+
import tempfile
35
from dataclasses import dataclass
46
from urllib.parse import urlparse
57

@@ -49,6 +51,7 @@ def run(self, direct, args):
4951
# Get SSH command and default username from driver
5052
ssh_command = self.call("get_ssh_command")
5153
default_username = self.call("get_default_username")
54+
ssh_identity = self.call("get_ssh_identity")
5255

5356
if direct:
5457
# Use direct TCP address
@@ -60,7 +63,7 @@ def run(self, direct, args):
6063
if not host or not port:
6164
raise ValueError(f"Invalid address format: {address}")
6265
self.logger.debug(f"Using direct TCP connection for SSH - host: {host}, port: {port}")
63-
return self._run_ssh_local(host, port, ssh_command, default_username, args)
66+
return self._run_ssh_local(host, port, ssh_command, default_username, ssh_identity, args)
6467
except (DriverMethodNotImplemented, ValueError) as e:
6568
self.logger.error(f"Direct address connection failed ({e}), falling back to SSH port forwarding")
6669
return self.run(False, args)
@@ -73,27 +76,61 @@ def run(self, direct, args):
7376
host = addr[0]
7477
port = addr[1]
7578
self.logger.debug(f"SSH port forward established - host: {host}, port: {port}")
76-
return self._run_ssh_local(host, port, ssh_command, default_username, args)
79+
return self._run_ssh_local(host, port, ssh_command, default_username, ssh_identity, args)
7780

78-
def _run_ssh_local(self, host, port, ssh_command, default_username, args):
81+
def _run_ssh_local(self, host, port, ssh_command, default_username, ssh_identity, args):
7982
"""Run SSH command with the given host, port, and arguments"""
80-
# Build SSH command arguments
81-
ssh_args = self._build_ssh_command_args(ssh_command, port, default_username, args)
82-
83-
# Separate SSH options from command arguments
84-
ssh_options, command_args = self._separate_ssh_options_and_command_args(args)
85-
86-
# Build final SSH command
87-
ssh_args = self._build_final_ssh_command(ssh_args, ssh_options, host, command_args)
88-
89-
# Execute the command
90-
return self._execute_ssh_command(ssh_args)
83+
# Create temporary identity file if needed
84+
identity_file = None
85+
temp_file = None
86+
if ssh_identity:
87+
try:
88+
temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_ssh_key')
89+
temp_file.write(ssh_identity)
90+
temp_file.close()
91+
# Set proper permissions (600) for SSH key
92+
os.chmod(temp_file.name, 0o600)
93+
identity_file = temp_file.name
94+
self.logger.debug(f"Created temporary identity file: {identity_file}")
95+
except Exception as e:
96+
self.logger.error(f"Failed to create temporary identity file: {e}")
97+
if temp_file:
98+
try:
99+
os.unlink(temp_file.name)
100+
except Exception:
101+
pass
102+
raise
91103

92-
def _build_ssh_command_args(self, ssh_command, port, default_username, args):
104+
try:
105+
# Build SSH command arguments
106+
ssh_args = self._build_ssh_command_args(ssh_command, port, default_username, identity_file, args)
107+
108+
# Separate SSH options from command arguments
109+
ssh_options, command_args = self._separate_ssh_options_and_command_args(args)
110+
111+
# Build final SSH command
112+
ssh_args = self._build_final_ssh_command(ssh_args, ssh_options, host, command_args)
113+
114+
# Execute the command
115+
return self._execute_ssh_command(ssh_args)
116+
finally:
117+
# Clean up temporary identity file
118+
if identity_file:
119+
try:
120+
os.unlink(identity_file)
121+
self.logger.debug(f"Cleaned up temporary identity file: {identity_file}")
122+
except Exception as e:
123+
self.logger.warning(f"Failed to clean up temporary identity file {identity_file}: {e}")
124+
125+
def _build_ssh_command_args(self, ssh_command, port, default_username, identity_file, args):
93126
"""Build initial SSH command arguments"""
94127
# Split the SSH command into individual arguments
95128
ssh_args = shlex.split(ssh_command)
96129

130+
# Add identity file if provided
131+
if identity_file:
132+
ssh_args.extend(["-i", identity_file])
133+
97134
# Add port if specified
98135
if port and port != 22:
99136
ssh_args.extend(["-p", str(port)])

packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import dataclass
2+
from pathlib import Path
23

34
from jumpstarter.common.exceptions import ConfigurationError
45
from jumpstarter.driver import Driver, export
@@ -10,6 +11,8 @@ class SSHWrapper(Driver):
1011

1112
default_username: str = ""
1213
ssh_command: str = "ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o LogLevel=ERROR"
14+
ssh_identity: str | None = None
15+
ssh_identity_file: str | None = None
1316

1417
def __post_init__(self):
1518
if hasattr(super(), "__post_init__"):
@@ -18,6 +21,16 @@ def __post_init__(self):
1821
if "tcp" not in self.children:
1922
raise ConfigurationError("'tcp' child is required via ref, or directly as a TcpNetwork driver instance")
2023

24+
if self.ssh_identity and self.ssh_identity_file:
25+
raise ConfigurationError("Cannot specify both ssh_identity and ssh_identity_file")
26+
27+
# If ssh_identity_file is provided, read it into ssh_identity
28+
if self.ssh_identity_file:
29+
try:
30+
self.ssh_identity = Path(self.ssh_identity_file).read_text()
31+
except Exception as e:
32+
raise ConfigurationError(f"Failed to read ssh_identity_file '{self.ssh_identity_file}': {e}") from None
33+
2134
@classmethod
2235
def client(cls) -> str:
2336
return "jumpstarter_driver_ssh.client.SSHWrapperClient"
@@ -31,3 +44,8 @@ def get_default_username(self):
3144
def get_ssh_command(self):
3245
"""Get the SSH command to use"""
3346
return self.ssh_command
47+
48+
@export
49+
def get_ssh_identity(self):
50+
"""Get the SSH identity key content"""
51+
return self.ssh_identity

0 commit comments

Comments
 (0)