Skip to content

Commit bc8fe62

Browse files
committed
Allow specifying missing_host_key_policy.
Properly silence bandit on XML: defusedxml is basically dead
1 parent aab0d26 commit bc8fe62

2 files changed

Lines changed: 11 additions & 5 deletions

File tree

exec_helpers/_ssh_base.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,8 @@ class SSHClientBase(api.ExecHelper):
457457
:type keepalive: int | bool
458458
:param allow_ssh_agent: Use SSH Agent if available.
459459
:type allow_ssh_agent: bool
460+
:param missing_host_key_policy: Missing host key policy class. See paramiko.MissingHostKeyPolicy for details.
461+
:type missing_host_key_policy: type[paramiko.MissingHostKeyPolicy]
460462
461463
.. note:: auth has priority over username/password/private_keys.
462464
.. note::
@@ -475,6 +477,7 @@ class SSHClientBase(api.ExecHelper):
475477
.. versionchanged:: 7.0.0 keepalive_mode is removed.
476478
.. versionchanged:: 7.4.0 Return of keepalive_mode to prevent mix with a keepalive period. Default is `False`.
477479
.. versionchanged:: 8.0.0 Expose SSH Agent usage override.
480+
.. versionchanged:: 8.1.4 Allow specifying missing_host_key_policy.
478481
"""
479482

480483
__slots__ = (
@@ -491,6 +494,7 @@ class SSHClientBase(api.ExecHelper):
491494
"__ssh_config",
492495
"__sudo_mode",
493496
"__verbose",
497+
"_missing_host_key_policy",
494498
)
495499

496500
def __hash__(self) -> int:
@@ -515,13 +519,15 @@ def __init__(
515519
sock: paramiko.ProxyCommand | paramiko.Channel | socket.socket | None = None,
516520
keepalive: KeepAlivePeriodT = 1,
517521
allow_ssh_agent: bool = True,
522+
missing_host_key_policy: type[paramiko.MissingHostKeyPolicy] = paramiko.WarningPolicy,
518523
) -> None:
519524
"""Main SSH Client helper."""
520525
self.__sudo_mode = False
521526
self.__keepalive_period: int = int(keepalive)
522527
self.__keepalive_mode = False
523528
self.__verbose: bool = verbose
524529
self.__sock = sock
530+
self._missing_host_key_policy = missing_host_key_policy
525531

526532
self.__ssh: paramiko.SSHClient
527533
self.__sftp: paramiko.SFTPClient | None = None
@@ -755,7 +761,7 @@ def __connect(self) -> None:
755761
with self.lock:
756762
if self.__sock is not None:
757763
self.__ssh = paramiko.SSHClient()
758-
self.__ssh.set_missing_host_key_policy(paramiko.WarningPolicy())
764+
self.__ssh.set_missing_host_key_policy(self._missing_host_key_policy())
759765
self.auth.connect(
760766
client=self.__ssh,
761767
hostname=self.hostname,
@@ -781,7 +787,7 @@ def __get_client(self) -> paramiko.SSHClient:
781787
"""
782788

783789
last_ssh_client: paramiko.SSHClient = paramiko.SSHClient()
784-
last_ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy()) # noqa: S507,RUF100
790+
last_ssh_client.set_missing_host_key_policy(self._missing_host_key_policy())
785791

786792
config, auth = self.__conn_chain[0]
787793

@@ -795,7 +801,7 @@ def __get_client(self) -> paramiko.SSHClient:
795801

796802
for config, auth in self.__conn_chain[1:]: # start has another logic, so do it out of cycle
797803
ssh = paramiko.SSHClient()
798-
ssh.set_missing_host_key_policy(paramiko.WarningPolicy()) # noqa: S507,RUF100
804+
ssh.set_missing_host_key_policy(self._missing_host_key_policy())
799805

800806
if config.proxyjump:
801807
transport = last_ssh_client.get_transport()

exec_helpers/exec_result.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ def stdout_xml(self) -> xml.etree.ElementTree.Element: # type: ignore[name-defi
678678
:raises DeserializeValueError: STDOUT cannot be deserialized as XML.
679679
"""
680680
with self.stdout_lock:
681-
return xml.etree.ElementTree.fromstring(b"".join(self.stdout)) # type: ignore[attr-defined]
681+
return xml.etree.ElementTree.fromstring(b"".join(self.stdout)) # type: ignore[attr-defined] # nosec[B314]
682682

683683
if lxml is not None:
684684

@@ -694,7 +694,7 @@ def stdout_lxml(self) -> lxml.etree.Element:
694694
.. note:: Can be insecure.
695695
"""
696696
with self.stdout_lock:
697-
return lxml.etree.fromstring(b"".join(self.stdout)) # nosec[blacklist]
697+
return lxml.etree.fromstring(b"".join(self.stdout)) # nosec[B314]
698698

699699
def __dir__(self) -> list[str]:
700700
"""Override dir for IDE and as source for getitem checks.

0 commit comments

Comments
 (0)