1+ import os
12import shlex
23import subprocess
4+ import tempfile
35from dataclasses import dataclass
46from urllib .parse import urlparse
57
@@ -45,6 +47,7 @@ def run(self, direct, args):
4547 # Get SSH command and default username from driver
4648 ssh_command = self .call ("get_ssh_command" )
4749 default_username = self .call ("get_default_username" )
50+ ssh_identity = self .call ("get_ssh_identity" )
4851
4952 if direct :
5053 # Use direct TCP address
@@ -56,7 +59,7 @@ def run(self, direct, args):
5659 if not host or not port :
5760 raise ValueError (f"Invalid address format: { address } " )
5861 self .logger .debug (f"Using direct TCP connection for SSH - host: { host } , port: { port } " )
59- return self ._run_ssh_local (host , port , ssh_command , default_username , args )
62+ return self ._run_ssh_local (host , port , ssh_command , default_username , ssh_identity , args )
6063 except (DriverMethodNotImplemented , ValueError ) as e :
6164 self .logger .error (f"Direct address connection failed ({ e } ), falling back to SSH port forwarding" )
6265 return self .run (False , args )
@@ -69,27 +72,61 @@ def run(self, direct, args):
6972 host = addr [0 ]
7073 port = addr [1 ]
7174 self .logger .debug (f"SSH port forward established - host: { host } , port: { port } " )
72- return self ._run_ssh_local (host , port , ssh_command , default_username , args )
75+ return self ._run_ssh_local (host , port , ssh_command , default_username , ssh_identity , args )
7376
74- def _run_ssh_local (self , host , port , ssh_command , default_username , args ):
77+ def _run_ssh_local (self , host , port , ssh_command , default_username , ssh_identity , args ):
7578 """Run SSH command with the given host, port, and arguments"""
76- # Build SSH command arguments
77- ssh_args = self ._build_ssh_command_args (ssh_command , port , default_username , args )
78-
79- # Separate SSH options from command arguments
80- ssh_options , command_args = self ._separate_ssh_options_and_command_args (args )
81-
82- # Build final SSH command
83- ssh_args = self ._build_final_ssh_command (ssh_args , ssh_options , host , command_args )
84-
85- # Execute the command
86- return self ._execute_ssh_command (ssh_args )
79+ # Create temporary identity file if needed
80+ identity_file = None
81+ temp_file = None
82+ if ssh_identity :
83+ try :
84+ temp_file = tempfile .NamedTemporaryFile (mode = 'w' , delete = False , suffix = '_ssh_key' )
85+ temp_file .write (ssh_identity )
86+ temp_file .close ()
87+ # Set proper permissions (600) for SSH key
88+ os .chmod (temp_file .name , 0o600 )
89+ identity_file = temp_file .name
90+ self .logger .debug (f"Created temporary identity file: { identity_file } " )
91+ except Exception as e :
92+ self .logger .error (f"Failed to create temporary identity file: { e } " )
93+ if temp_file :
94+ try :
95+ os .unlink (temp_file .name )
96+ except Exception :
97+ pass
98+ raise
8799
88- def _build_ssh_command_args (self , ssh_command , port , default_username , args ):
100+ try :
101+ # Build SSH command arguments
102+ ssh_args = self ._build_ssh_command_args (ssh_command , port , default_username , identity_file , args )
103+
104+ # Separate SSH options from command arguments
105+ ssh_options , command_args = self ._separate_ssh_options_and_command_args (args )
106+
107+ # Build final SSH command
108+ ssh_args = self ._build_final_ssh_command (ssh_args , ssh_options , host , command_args )
109+
110+ # Execute the command
111+ return self ._execute_ssh_command (ssh_args )
112+ finally :
113+ # Clean up temporary identity file
114+ if identity_file :
115+ try :
116+ os .unlink (identity_file )
117+ self .logger .debug (f"Cleaned up temporary identity file: { identity_file } " )
118+ except Exception as e :
119+ self .logger .warning (f"Failed to clean up temporary identity file { identity_file } : { e } " )
120+
121+ def _build_ssh_command_args (self , ssh_command , port , default_username , identity_file , args ):
89122 """Build initial SSH command arguments"""
90123 # Split the SSH command into individual arguments
91124 ssh_args = shlex .split (ssh_command )
92125
126+ # Add identity file if provided
127+ if identity_file :
128+ ssh_args .extend (["-i" , identity_file ])
129+
93130 # Add port if specified
94131 if port and port != 22 :
95132 ssh_args .extend (["-p" , str (port )])
0 commit comments