Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.

Commit 67ff077

Browse files
add ssh identity configuration
specify identity to use with ssh_identity_file or directly using ssh_identity string Temporary file with that content will be created on the client side and passed to ssh -i option
1 parent 6266878 commit 67ff077

2 files changed

Lines changed: 70 additions & 15 deletions

File tree

packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import os
12
import shlex
23
import subprocess
4+
import tempfile
35
from dataclasses import dataclass
46
from urllib.parse import urlparse
57

@@ -45,6 +47,7 @@ def run(self, direct, args):
4547
# Get SSH command and default username from driver
4648
ssh_command = self.call("get_ssh_command")
4749
default_username = self.call("get_default_username")
50+
ssh_identity = self.call("get_ssh_identity")
4851

4952
if direct:
5053
# Use direct TCP address
@@ -56,7 +59,7 @@ def run(self, direct, args):
5659
if not host or not port:
5760
raise ValueError(f"Invalid address format: {address}")
5861
self.logger.debug(f"Using direct TCP connection for SSH - host: {host}, port: {port}")
59-
return self._run_ssh_local(host, port, ssh_command, default_username, args)
62+
return self._run_ssh_local(host, port, ssh_command, default_username, ssh_identity, args)
6063
except (DriverMethodNotImplemented, ValueError) as e:
6164
self.logger.error(f"Direct address connection failed ({e}), falling back to SSH port forwarding")
6265
return self.run(False, args)
@@ -69,27 +72,61 @@ def run(self, direct, args):
6972
host = addr[0]
7073
port = addr[1]
7174
self.logger.debug(f"SSH port forward established - host: {host}, port: {port}")
72-
return self._run_ssh_local(host, port, ssh_command, default_username, args)
75+
return self._run_ssh_local(host, port, ssh_command, default_username, ssh_identity, args)
7376

74-
def _run_ssh_local(self, host, port, ssh_command, default_username, args):
77+
def _run_ssh_local(self, host, port, ssh_command, default_username, ssh_identity, args):
7578
"""Run SSH command with the given host, port, and arguments"""
76-
# Build SSH command arguments
77-
ssh_args = self._build_ssh_command_args(ssh_command, port, default_username, args)
78-
79-
# Separate SSH options from command arguments
80-
ssh_options, command_args = self._separate_ssh_options_and_command_args(args)
81-
82-
# Build final SSH command
83-
ssh_args = self._build_final_ssh_command(ssh_args, ssh_options, host, command_args)
84-
85-
# Execute the command
86-
return self._execute_ssh_command(ssh_args)
79+
# Create temporary identity file if needed
80+
identity_file = None
81+
temp_file = None
82+
if ssh_identity:
83+
try:
84+
temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_ssh_key')
85+
temp_file.write(ssh_identity)
86+
temp_file.close()
87+
# Set proper permissions (600) for SSH key
88+
os.chmod(temp_file.name, 0o600)
89+
identity_file = temp_file.name
90+
self.logger.debug(f"Created temporary identity file: {identity_file}")
91+
except Exception as e:
92+
self.logger.error(f"Failed to create temporary identity file: {e}")
93+
if temp_file:
94+
try:
95+
os.unlink(temp_file.name)
96+
except Exception:
97+
pass
98+
raise
8799

88-
def _build_ssh_command_args(self, ssh_command, port, default_username, args):
100+
try:
101+
# Build SSH command arguments
102+
ssh_args = self._build_ssh_command_args(ssh_command, port, default_username, identity_file, args)
103+
104+
# Separate SSH options from command arguments
105+
ssh_options, command_args = self._separate_ssh_options_and_command_args(args)
106+
107+
# Build final SSH command
108+
ssh_args = self._build_final_ssh_command(ssh_args, ssh_options, host, command_args)
109+
110+
# Execute the command
111+
return self._execute_ssh_command(ssh_args)
112+
finally:
113+
# Clean up temporary identity file
114+
if identity_file:
115+
try:
116+
os.unlink(identity_file)
117+
self.logger.debug(f"Cleaned up temporary identity file: {identity_file}")
118+
except Exception as e:
119+
self.logger.warning(f"Failed to clean up temporary identity file {identity_file}: {e}")
120+
121+
def _build_ssh_command_args(self, ssh_command, port, default_username, identity_file, args):
89122
"""Build initial SSH command arguments"""
90123
# Split the SSH command into individual arguments
91124
ssh_args = shlex.split(ssh_command)
92125

126+
# Add identity file if provided
127+
if identity_file:
128+
ssh_args.extend(["-i", identity_file])
129+
93130
# Add port if specified
94131
if port and port != 22:
95132
ssh_args.extend(["-p", str(port)])

packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import dataclass
2+
from pathlib import Path
23

34
from jumpstarter.common.exceptions import ConfigurationError
45
from jumpstarter.driver import Driver, export
@@ -10,6 +11,8 @@ class SSHWrapper(Driver):
1011

1112
default_username: str = ""
1213
ssh_command: str = "ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o LogLevel=ERROR"
14+
ssh_identity: str | None = None
15+
ssh_identity_file: str | None = None
1316

1417
def __post_init__(self):
1518
if hasattr(super(), "__post_init__"):
@@ -18,6 +21,16 @@ def __post_init__(self):
1821
if "tcp" not in self.children:
1922
raise ConfigurationError("'tcp' child is required via ref, or directly as a TcpNetwork driver instance")
2023

24+
if self.ssh_identity and self.ssh_identity_file:
25+
raise ConfigurationError("Cannot specify both ssh_identity and ssh_identity_file")
26+
27+
# If ssh_identity_file is provided, read it into ssh_identity
28+
if self.ssh_identity_file:
29+
try:
30+
self.ssh_identity = Path(self.ssh_identity_file).read_text()
31+
except Exception as e:
32+
raise ConfigurationError(f"Failed to read ssh_identity_file '{self.ssh_identity_file}': {e}") from None
33+
2134
@classmethod
2235
def client(cls) -> str:
2336
return "jumpstarter_driver_ssh.client.SSHWrapperClient"
@@ -31,3 +44,8 @@ def get_default_username(self):
3144
def get_ssh_command(self):
3245
"""Get the SSH command to use"""
3346
return self.ssh_command
47+
48+
@export
49+
def get_ssh_identity(self):
50+
"""Get the SSH identity key content"""
51+
return self.ssh_identity

0 commit comments

Comments
 (0)