1+ import os
12import shlex
23import subprocess
4+ import tempfile
35from dataclasses import dataclass
46from urllib .parse import urlparse
57
@@ -49,6 +51,7 @@ def run(self, direct, args):
4951 # Get SSH command and default username from driver
5052 ssh_command = self .call ("get_ssh_command" )
5153 default_username = self .call ("get_default_username" )
54+ ssh_identity = self .call ("get_ssh_identity" )
5255
5356 if direct :
5457 # Use direct TCP address
@@ -60,7 +63,7 @@ def run(self, direct, args):
6063 if not host or not port :
6164 raise ValueError (f"Invalid address format: { address } " )
6265 self .logger .debug (f"Using direct TCP connection for SSH - host: { host } , port: { port } " )
63- return self ._run_ssh_local (host , port , ssh_command , default_username , args )
66+ return self ._run_ssh_local (host , port , ssh_command , default_username , ssh_identity , args )
6467 except (DriverMethodNotImplemented , ValueError ) as e :
6568 self .logger .error (f"Direct address connection failed ({ e } ), falling back to SSH port forwarding" )
6669 return self .run (False , args )
@@ -73,27 +76,61 @@ def run(self, direct, args):
7376 host = addr [0 ]
7477 port = addr [1 ]
7578 self .logger .debug (f"SSH port forward established - host: { host } , port: { port } " )
76- return self ._run_ssh_local (host , port , ssh_command , default_username , args )
79+ return self ._run_ssh_local (host , port , ssh_command , default_username , ssh_identity , args )
7780
78- def _run_ssh_local (self , host , port , ssh_command , default_username , args ):
81+ def _run_ssh_local (self , host , port , ssh_command , default_username , ssh_identity , args ):
7982 """Run SSH command with the given host, port, and arguments"""
80- # Build SSH command arguments
81- ssh_args = self ._build_ssh_command_args (ssh_command , port , default_username , args )
82-
83- # Separate SSH options from command arguments
84- ssh_options , command_args = self ._separate_ssh_options_and_command_args (args )
85-
86- # Build final SSH command
87- ssh_args = self ._build_final_ssh_command (ssh_args , ssh_options , host , command_args )
88-
89- # Execute the command
90- return self ._execute_ssh_command (ssh_args )
83+ # Create temporary identity file if needed
84+ identity_file = None
85+ temp_file = None
86+ if ssh_identity :
87+ try :
88+ temp_file = tempfile .NamedTemporaryFile (mode = 'w' , delete = False , suffix = '_ssh_key' )
89+ temp_file .write (ssh_identity )
90+ temp_file .close ()
91+ # Set proper permissions (600) for SSH key
92+ os .chmod (temp_file .name , 0o600 )
93+ identity_file = temp_file .name
94+ self .logger .debug (f"Created temporary identity file: { identity_file } " )
95+ except Exception as e :
96+ self .logger .error (f"Failed to create temporary identity file: { e } " )
97+ if temp_file :
98+ try :
99+ os .unlink (temp_file .name )
100+ except Exception :
101+ pass
102+ raise
91103
92- def _build_ssh_command_args (self , ssh_command , port , default_username , args ):
104+ try :
105+ # Build SSH command arguments
106+ ssh_args = self ._build_ssh_command_args (ssh_command , port , default_username , identity_file , args )
107+
108+ # Separate SSH options from command arguments
109+ ssh_options , command_args = self ._separate_ssh_options_and_command_args (args )
110+
111+ # Build final SSH command
112+ ssh_args = self ._build_final_ssh_command (ssh_args , ssh_options , host , command_args )
113+
114+ # Execute the command
115+ return self ._execute_ssh_command (ssh_args )
116+ finally :
117+ # Clean up temporary identity file
118+ if identity_file :
119+ try :
120+ os .unlink (identity_file )
121+ self .logger .debug (f"Cleaned up temporary identity file: { identity_file } " )
122+ except Exception as e :
123+ self .logger .warning (f"Failed to clean up temporary identity file { identity_file } : { e } " )
124+
125+ def _build_ssh_command_args (self , ssh_command , port , default_username , identity_file , args ):
93126 """Build initial SSH command arguments"""
94127 # Split the SSH command into individual arguments
95128 ssh_args = shlex .split (ssh_command )
96129
130+ # Add identity file if provided
131+ if identity_file :
132+ ssh_args .extend (["-i" , identity_file ])
133+
97134 # Add port if specified
98135 if port and port != 22 :
99136 ssh_args .extend (["-p" , str (port )])
0 commit comments