Skip to content

Commit

Permalink
Add: CI parameter to SSH connection for automatic skip of user input (#…
Browse files Browse the repository at this point in the history
…1009)

* Add: CI parameter to SSH connection for automatic skip of user input

* Refactor

* Code review
  • Loading branch information
y0urself authored Apr 26, 2023
1 parent e57a46c commit f02ea0d
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 17 deletions.
58 changes: 41 additions & 17 deletions gvm/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,8 @@ def read(self) -> str:
if self._is_end_xml():
break

if self._timeout is not None:
now = time.time()

if now > break_timeout:
raise GvmError("Timeout while reading the response")
if time.time() > break_timeout:
raise GvmError("Timeout while reading the response")

return response

Expand Down Expand Up @@ -194,6 +191,7 @@ def __init__(
username: Optional[str] = DEFAULT_SSH_USERNAME,
password: Optional[str] = DEFAULT_SSH_PASSWORD,
known_hosts_file: Optional[str] = None,
auto_accept_host: Optional[bool] = None,
):
super().__init__(timeout=timeout)

Expand All @@ -210,6 +208,7 @@ def __init__(
if known_hosts_file is not None
else Path.home() / DEFAULT_KNOWN_HOSTS_FILE
)
self.auto_accept_host = auto_accept_host

def _send_all(self, data) -> int:
"""Returns the sum of sent bytes if success"""
Expand All @@ -226,6 +225,32 @@ def _send_all(self, data) -> int:
data = data[sent:]
return sent_sum

def _auto_accept_host(
self, hostkeys: paramiko.HostKeys, key: paramiko.PKey
) -> None:
if self.port == DEFAULT_SSH_PORT:
hostkeys.add(self.hostname, key.get_name(), key)
elif self.port != DEFAULT_SSH_PORT:
hostkeys.add(
"[" + self.hostname + "]:" + str(self.port),
key.get_name(),
key,
)
try:
hostkeys.save(filename=self.known_hosts_file)
except OSError as e:
raise GvmError(
"Something went wrong with writing "
f"the known_hosts file: {e}"
) from None
key_type = key.get_name().replace("ssh-", "").upper()
logger.info(
"Warning: Permanently added '%s' (%s) to "
"the list of known hosts.",
self.hostname,
key_type,
)

def _ssh_authentication_input_loop(
self, hostkeys: paramiko.HostKeys, key: paramiko.PKey
) -> None:
Expand Down Expand Up @@ -342,18 +367,17 @@ def _ssh_authentication(self) -> None:
hostkeys = self._socket.get_host_keys()
# Switch based on SSH Port
if self.port == DEFAULT_SSH_PORT:
if not hostkeys.lookup(self.hostname):
# Key not found, so connect to remote and fetch the key
# with the paramiko Transport protocol
key = self._get_remote_host_key()

self._ssh_authentication_input_loop(hostkeys=hostkeys, key=key)
elif self.port != DEFAULT_SSH_PORT:
if not hostkeys.lookup("[" + self.hostname + "]:" + str(self.port)):
# Key not found, so connect to remote and fetch the key
# with the paramiko Transport protocol
key = self._get_remote_host_key()

hostname = self.hostname
else:
hostname = f"[{self.hostname}]:{self.port}"

if not hostkeys.lookup(hostname):
# Key not found, so connect to remote and fetch the key
# with the paramiko Transport protocol
key = self._get_remote_host_key()
if self.auto_accept_host:
self._auto_accept_host(hostkeys=hostkeys, key=key)
else:
self._ssh_authentication_input_loop(hostkeys=hostkeys, key=key)

def connect(self) -> None:
Expand Down
21 changes: 21 additions & 0 deletions tests/connections/test_ssh_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ def test_connect_error(self):
with self.assertRaises(GvmError, msg="SSH Connection failed"):
ssh_connection.connect()

def test_connect_error_auto_accept_host(self):
ssh_connection = SSHConnection(
known_hosts_file=self.known_hosts_file, auto_accept_host=True
)
with self.assertRaises(GvmError, msg="SSH Connection failed"):
ssh_connection.connect()

def test_connect(self):
with patch("paramiko.SSHClient") as SSHClientMock:
client_mock = SSHClientMock.return_value
Expand All @@ -91,6 +98,20 @@ def test_connect(self):
self.assertEqual(ssh_connection._stderr, "c")
ssh_connection.disconnect()

def test_connect_auto_accept_host(self):
with patch("paramiko.SSHClient") as SSHClientMock:
client_mock = SSHClientMock.return_value
client_mock.exec_command.return_value = ["a", "b", "c"]
ssh_connection = SSHConnection(
known_hosts_file=self.known_hosts_file, auto_accept_host=True
)

ssh_connection.connect()
self.assertEqual(ssh_connection._stdin, "a")
self.assertEqual(ssh_connection._stdout, "b")
self.assertEqual(ssh_connection._stderr, "c")
ssh_connection.disconnect()

def test_connect_unknown_host(self):
ssh_connection = SSHConnection(
hostname="0.0.0.1",
Expand Down

0 comments on commit f02ea0d

Please sign in to comment.