Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.

Commit 8e91304

Browse files
committed
driver-ssh: pass options dataclass to run
Add SSHCommandRunOptions class and modify the run method signature to use it. This will make it easier to change the options before having to change the signature in the future. Signed-off-by: Albert Esteve <aesteve@redhat.com>
1 parent cc71afe commit 8e91304

2 files changed

Lines changed: 47 additions & 27 deletions

File tree

packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@ def from_completed_process(result: subprocess.CompletedProcess) -> "SSHCommandRu
2929
)
3030

3131

32+
@dataclass
33+
class SSHCommandRunOptions:
34+
"""Options for running an SSH command"""
35+
direct: bool = False
36+
capture_output: bool = True
37+
capture_as_text: bool = True
38+
39+
3240
@dataclass(kw_only=True)
3341
class SSHWrapperClient(CompositeClient):
3442
"""
@@ -46,7 +54,15 @@ def cli(self):
4654
@click.option("--direct", is_flag=True, help="Use direct TCP address")
4755
@click.argument("args", nargs=-1)
4856
def ssh(direct, args):
49-
result = self.run(direct, args)
57+
options = SSHCommandRunOptions(
58+
direct=direct,
59+
# When no args are provided, an interactive shell is implied.
60+
# In this case, we must not capture stdout/stderr so the shell
61+
# can interact with the terminal.
62+
capture_output=bool(args),
63+
)
64+
65+
result = self.run(options, args)
5066
self.logger.debug("SSH exit code: %s", result.return_code)
5167

5268
if result.stdout:
@@ -69,14 +85,14 @@ def stream(self, method="connect"):
6985
async def stream_async(self, method):
7086
return await self.tcp.stream_async(method)
7187

72-
def run(self, direct, args) -> SSHCommandRunResult:
88+
def run(self, options: SSHCommandRunOptions, args) -> SSHCommandRunResult:
7389
"""Run SSH command with the given parameters and arguments"""
7490
# Get SSH command and default username from driver
7591
ssh_command = self.call("get_ssh_command")
7692
default_username = self.call("get_default_username")
7793
ssh_identity = self.call("get_ssh_identity")
7894

79-
if direct:
95+
if options.direct:
8096
# Use direct TCP address
8197
try:
8298
address = self.tcp.address() # (format: "tcp://host:port")
@@ -86,10 +102,14 @@ def run(self, direct, args) -> SSHCommandRunResult:
86102
if not host or not port:
87103
raise ValueError(f"Invalid address format: {address}")
88104
self.logger.debug(f"Using direct TCP connection for SSH - host: {host}, port: {port}")
89-
return self._run_ssh_local(host, port, ssh_command, default_username, ssh_identity, args)
105+
return self._run_ssh_local(host, port, ssh_command, options, default_username, ssh_identity, args)
90106
except (DriverMethodNotImplemented, ValueError) as e:
91107
self.logger.error(f"Direct address connection failed ({e}), falling back to SSH port forwarding")
92-
return self.run(False, args)
108+
return self.run(SSHCommandRunOptions(
109+
direct=False,
110+
capture_output=options.capture_output,
111+
capture_as_text=options.capture_as_text,
112+
), args)
93113
else:
94114
# Use SSH port forwarding (default behavior)
95115
self.logger.debug("Using SSH port forwarding for SSH connection")
@@ -98,9 +118,9 @@ def run(self, direct, args) -> SSHCommandRunResult:
98118
) as addr:
99119
host, port = addr
100120
self.logger.debug(f"SSH port forward established - host: {host}, port: {port}")
101-
return self._run_ssh_local(host, port, ssh_command, default_username, ssh_identity, args)
121+
return self._run_ssh_local(host, port, ssh_command, options, default_username, ssh_identity, args)
102122

103-
def _run_ssh_local(self, host, port, ssh_command, default_username, ssh_identity, args):
123+
def _run_ssh_local(self, host, port, ssh_command, options, default_username, ssh_identity, args):
104124
"""Run SSH command with the given host, port, and arguments"""
105125
# Create temporary identity file if needed
106126
identity_file = None
@@ -134,7 +154,7 @@ def _run_ssh_local(self, host, port, ssh_command, default_username, ssh_identity
134154
ssh_args = self._build_final_ssh_command(ssh_args, ssh_options, host, command_args)
135155

136156
# Execute the command
137-
return self._execute_ssh_command(ssh_args)
157+
return self._execute_ssh_command(ssh_args, options)
138158
finally:
139159
# Clean up temporary identity file
140160
if identity_file:
@@ -234,10 +254,10 @@ def _build_final_ssh_command(self, ssh_args, ssh_options, host, command_args):
234254
self.logger.debug(f"Running SSH command: {ssh_args}")
235255
return ssh_args
236256

237-
def _execute_ssh_command(self, ssh_args) -> SSHCommandRunResult:
257+
def _execute_ssh_command(self, ssh_args, options: SSHCommandRunOptions) -> SSHCommandRunResult:
238258
"""Execute the SSH command and return the result"""
239259
try:
240-
result = subprocess.run(ssh_args, capture_output=True, text=True)
260+
result = subprocess.run(ssh_args, capture_output=options.capture_output, text=options.capture_as_text)
241261
return SSHCommandRunResult.from_completed_process(result)
242262
except FileNotFoundError:
243263
self.logger.error(

packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
from jumpstarter_driver_network.driver import TcpNetwork
77

8-
from jumpstarter_driver_ssh.client import SSHCommandRunResult
8+
from jumpstarter_driver_ssh.client import SSHCommandRunOptions, SSHCommandRunResult
99
from jumpstarter_driver_ssh.driver import SSHWrapper
1010

1111
from jumpstarter.common.exceptions import ConfigurationError
@@ -55,7 +55,7 @@ def test_ssh_command_with_default_username():
5555
mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="")
5656

5757
# Test SSH command with default username
58-
result = client.run(False, ["hostname"])
58+
result = client.run(SSHCommandRunOptions(direct=False), ["hostname"])
5959
assert isinstance(result, SSHCommandRunResult)
6060

6161
# Verify subprocess.run was called
@@ -87,7 +87,7 @@ def test_ssh_command_without_default_username():
8787
mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="")
8888

8989
# Test SSH command without default username
90-
result = client.run(False, ["hostname"])
90+
result = client.run(SSHCommandRunOptions(direct=False), ["hostname"])
9191
assert isinstance(result, SSHCommandRunResult)
9292

9393
# Verify subprocess.run was called
@@ -117,7 +117,7 @@ def test_ssh_command_with_user_override():
117117
mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="")
118118

119119
# Test SSH command with -l flag overriding default username
120-
result = client.run(False, ["-l", "myuser", "hostname"])
120+
result = client.run(SSHCommandRunOptions(direct=False), ["-l", "myuser", "hostname"])
121121
assert isinstance(result, SSHCommandRunResult)
122122

123123
# Verify subprocess.run was called
@@ -155,7 +155,7 @@ def test_ssh_command_with_port():
155155
mock_adapter.return_value.__exit__.return_value = None
156156

157157
# Test SSH command with custom port
158-
result = client.run(False, ["hostname"])
158+
result = client.run(SSHCommandRunOptions(direct=False), ["hostname"])
159159
assert isinstance(result, SSHCommandRunResult)
160160

161161
# Verify subprocess.run was called
@@ -193,7 +193,7 @@ def test_ssh_command_with_direct_flag():
193193
# Mock the tcp.address() method
194194
with patch.object(client.tcp, 'address', return_value="tcp://192.168.1.100:22"):
195195
# Test SSH command with direct flag
196-
result = client.run(True, ["hostname"])
196+
result = client.run(SSHCommandRunOptions(direct=True), ["hostname"])
197197
assert isinstance(result, SSHCommandRunResult)
198198

199199
# Verify subprocess.run was called
@@ -224,7 +224,7 @@ def test_ssh_command_error_handling():
224224
mock_run.side_effect = FileNotFoundError("SSH not found")
225225

226226
# Test SSH command error handling
227-
result = client.run(False, ["hostname"])
227+
result = client.run(SSHCommandRunOptions(direct=False), ["hostname"])
228228
assert isinstance(result, SSHCommandRunResult)
229229

230230
# Should return error code 127
@@ -245,7 +245,7 @@ def test_ssh_command_with_multiple_ssh_options():
245245
mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="")
246246

247247
# Test SSH command with multiple SSH options
248-
result = client.run(False, [
248+
result = client.run(SSHCommandRunOptions(direct=False), [
249249
"-o", "StrictHostKeyChecking=no", "-i", "/path/to/key", "command", "arg1", "arg2"
250250
])
251251
assert isinstance(result, SSHCommandRunResult)
@@ -283,7 +283,7 @@ def test_ssh_command_with_unknown_option_treated_as_command():
283283
mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="")
284284

285285
# Test SSH command with unknown option
286-
result = client.run(False, ["-l", "user", "-unknown", "command", "arg1"])
286+
result = client.run(SSHCommandRunOptions(direct=False), ["-l", "user", "-unknown", "command", "arg1"])
287287
assert isinstance(result, SSHCommandRunResult)
288288

289289
# Verify subprocess.run was called
@@ -317,7 +317,7 @@ def test_ssh_command_with_no_ssh_options():
317317
mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="")
318318

319319
# Test SSH command with no SSH options
320-
result = client.run(False, ["command", "arg1", "arg2"])
320+
result = client.run(SSHCommandRunOptions(direct=False), ["command", "arg1", "arg2"])
321321
assert isinstance(result, SSHCommandRunResult)
322322

323323
# Verify subprocess.run was called
@@ -347,7 +347,7 @@ def test_ssh_command_with_command_l_flag_does_not_interfere_with_username_inject
347347
mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="")
348348

349349
# Test SSH command with -l flag in the command (like ls -la -l ajo)
350-
result = client.run(False, ["ls", "-la", "-l", "ajo"])
350+
result = client.run(SSHCommandRunOptions(direct=False), ["ls", "-la", "-l", "ajo"])
351351
assert isinstance(result, SSHCommandRunResult)
352352

353353
# Verify subprocess.run was called
@@ -469,7 +469,7 @@ def test_ssh_command_with_identity_string():
469469
mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="")
470470

471471
# Test SSH command with identity string
472-
result = client.run(False, ["hostname"])
472+
result = client.run(SSHCommandRunOptions(direct=False), ["hostname"])
473473
assert isinstance(result, SSHCommandRunResult)
474474

475475
# Verify subprocess.run was called
@@ -519,7 +519,7 @@ def test_ssh_command_with_identity_file():
519519
mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="")
520520

521521
# Test SSH command with identity file
522-
result = client.run(False, ["hostname"])
522+
result = client.run(SSHCommandRunOptions(direct=False), ["hostname"])
523523
assert isinstance(result, SSHCommandRunResult)
524524

525525
# Verify subprocess.run was called
@@ -563,7 +563,7 @@ def test_ssh_command_without_identity():
563563
mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="")
564564

565565
# Test SSH command without identity
566-
result = client.run(False, ["hostname"])
566+
result = client.run(SSHCommandRunOptions(direct=False), ["hostname"])
567567
assert isinstance(result, SSHCommandRunResult)
568568

569569
# Verify subprocess.run was called
@@ -608,7 +608,7 @@ def test_ssh_identity_temp_file_creation_and_cleanup():
608608
mock_temp_file.return_value = mock_temp_file_instance
609609

610610
# Test SSH command with identity
611-
result = client.run(False, ["hostname"])
611+
result = client.run(SSHCommandRunOptions(direct=False), ["hostname"])
612612
assert isinstance(result, SSHCommandRunResult)
613613

614614
# Verify temporary file was created
@@ -644,7 +644,7 @@ def test_ssh_identity_temp_file_creation_error():
644644
# Test SSH command with identity should raise an error
645645
# The exception will be wrapped in an ExceptionGroup due to the context manager
646646
with pytest.raises(ExceptionGroup) as exc_info:
647-
client.run(False, ["hostname"])
647+
client.run(SSHCommandRunOptions(direct=False), ["hostname"])
648648

649649
# Check that the original OSError is in the exception group
650650
assert any(isinstance(e, OSError) and "Permission denied" in str(e) for e in exc_info.value.exceptions)
@@ -677,7 +677,7 @@ def test_ssh_identity_temp_file_cleanup_error():
677677

678678
# Test SSH command with identity - should still succeed but log warning
679679
with patch.object(client, 'logger') as mock_logger:
680-
result = client.run(False, ["hostname"])
680+
result = client.run(SSHCommandRunOptions(direct=False), ["hostname"])
681681
assert isinstance(result, SSHCommandRunResult)
682682

683683
# Verify chmod was called

0 commit comments

Comments
 (0)