diff --git a/CHANGELOG.md b/CHANGELOG.md index cc005a7eb..98c5226bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ and this project adheres to [Calendar Versioning](https://calver.org)html). ## [Unreleased] ### Added ### Changed +* For SSH Connections: Reject unknown hosts, ask user if he wants to connect to unknown remote host and ask user if he wants to add the host to `known_hosts` [#486](https://github.com/greenbone/python-gvm/pull/486) + ### Deprecated ### Removed ### Fixed diff --git a/gvm/connections.py b/gvm/connections.py index 7d2c0d607..f217c2fd2 100644 --- a/gvm/connections.py +++ b/gvm/connections.py @@ -18,11 +18,16 @@ """ Module for connections to GVM server daemons like gvmd and ospd. """ +import base64 +import hashlib import logging import socket as socketlib import ssl +import sys import time + +from pathlib import Path from typing import Optional, Union import paramiko @@ -43,6 +48,7 @@ DEFAULT_SSH_USERNAME = "gmp" DEFAULT_SSH_PASSWORD = "" DEFAULT_HOSTNAME = '127.0.0.1' +DEFAULT_KNOWN_HOSTS_FILE = ".ssh/known_hosts" MAX_SSH_DATA_LENGTH = 4095 @@ -189,6 +195,7 @@ def __init__( port: Optional[int] = DEFAULT_SSH_PORT, username: Optional[str] = DEFAULT_SSH_USERNAME, password: Optional[str] = DEFAULT_SSH_PASSWORD, + known_hosts_file: Optional[str] = None, ): super().__init__(timeout=timeout) @@ -200,6 +207,11 @@ def __init__( self.password = ( password if password is not None else DEFAULT_SSH_PASSWORD ) + self.known_hosts_file = ( + Path(known_hosts_file) + if known_hosts_file is not None + else Path.home() / DEFAULT_KNOWN_HOSTS_FILE + ) def _send_all(self, data) -> int: """Returns the sum of sent bytes if success""" @@ -216,12 +228,124 @@ def _send_all(self, data) -> int: data = data[sent:] return sent_sum + def _ssh_authentication_input_loop( + self, hostkeys: paramiko.HostKeys, key: paramiko.PKey + ) -> None: + # Ask user for permission to continue + # let it look like openssh + sha64_fingerprint = base64.b64encode( + hashlib.sha256(base64.b64decode(key.get_base64())).digest() + ).decode("utf-8")[:-1] + key_type = key.get_name().replace('ssh-', '').upper() + print( + f"The authenticity of host '{self.hostname}' can't " + "be established." + ) + print(f"{key_type} key fingerprint is {sha64_fingerprint}.") + print('Are you sure you want to continue connecting (yes/no)? ', end='') + add = input() + while True: + if add == 'yes': + hostkeys.add(self.hostname, key.get_name(), key) + # ask user if the key should be added permanently + print( + f'Do you want to add {self.hostname} ' + 'to known_hosts (yes/no)? ', + end='', + ) + save = input() + while True: + if save == 'yes': + try: + hostkeys.save(filename=self.known_hosts_file) + except OSError as e: + raise GvmError( + 'Something went wrong with writing ' + f'the known_hosts file: {e}' + ) from None + logger.info( + "Warning: Permanently added '%s' (%s) to " + "the list of known hosts.", + self.hostname, + key_type, + ) + break + elif save == 'no': + logger.info( + "Warning: Host '%s' (%s) not added to " + "the list of known hosts.", + self.hostname, + key_type, + ) + break + else: + print("Please type 'yes' or 'no': ", end='') + save = input() + break + elif add == 'no': + return sys.exit( + 'User denied key. Host key verification failed.' + ) + else: + print("Please type 'yes' or 'no': ", end='') + add = input() + + def _get_remote_host_key(self): + """Get the remote host key for ssh connection""" + try: + tmp_socket = socketlib.socket() + tmp_socket.connect((self.hostname, 22)) + except OSError as e: + raise GvmError( + "Couldn't establish a connection to fetch the" + f" remote server key: {e}" + ) from None + + trans = paramiko.transport.Transport(tmp_socket) + try: + trans.start_client() + except paramiko.SSHException as e: + raise GvmError( + f"Couldn't fetch the remote server key: {e}" + ) from None + key = trans.get_remote_server_key() + try: + trans.close() + except paramiko.SSHException as e: + raise GvmError( + f"Couldn't close the connection to the remote server key: {e}" + ) from None + return key + + def _ssh_authentication(self) -> None: + """Search/add/save the servers key for the SSH authentication process""" + + # set to reject policy (avoid MITM attacks) + self._socket.set_missing_host_key_policy(paramiko.RejectPolicy()) + # openssh is posix, so this might only a posix approach + # https://stackoverflow.com/q/32945533 + try: + # load the keys into paramiko and check if remote is in the list + self._socket.load_host_keys(filename=self.known_hosts_file) + except OSError as e: + raise GvmError( + 'Something went wrong with reading ' + f'the known_hosts file: {e}' + ) from None + hostkeys = self._socket.get_host_keys() + if not hostkeys.lookup(self.hostname): + # Key not found, so connect to remote and fetch the key + # with the paramiko Transport protocol + key = self._get_remote_host_key() + + self._ssh_authentication_input_loop(hostkeys=hostkeys, key=key) + def connect(self) -> None: """ Connect to the SSH server and authenticate to it """ self._socket = paramiko.SSHClient() - self._socket.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + self._ssh_authentication() try: self._socket.connect( @@ -242,8 +366,9 @@ def connect(self) -> None: paramiko.BadHostKeyException, paramiko.AuthenticationException, paramiko.SSHException, + ConnectionError, ) as e: - raise GvmError("SSH Connection failed", e) from None + raise GvmError(f"SSH Connection failed: {e}") from None def _read(self) -> bytes: return self._stdout.channel.recv(BUF_SIZE) diff --git a/tests/connections/test_ssh_connection.py b/tests/connections/test_ssh_connection.py index df0c95299..c452b1d13 100644 --- a/tests/connections/test_ssh_connection.py +++ b/tests/connections/test_ssh_connection.py @@ -16,20 +16,36 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from io import StringIO import unittest from unittest.mock import patch, Mock +from pathlib import Path +import paramiko from gvm.connections import ( SSHConnection, DEFAULT_SSH_PORT, DEFAULT_SSH_USERNAME, DEFAULT_SSH_PASSWORD, DEFAULT_HOSTNAME, + DEFAULT_KNOWN_HOSTS_FILE, ) from gvm.errors import GvmError class SSHConnectionTestCase(unittest.TestCase): # pylint: disable=protected-access, invalid-name + def setUp(self): + self.known_hosts_file = Path('known_hosts') + with self.known_hosts_file.open("a") as fp: + fp.write( + '127.0.0.1 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBOZWi' + 'fs+DoMqIa5Nr0wiVrzQNpMbUwaLzuSTN6rNrYA\n' + ) + + def tearDown(self): + if self.known_hosts_file.exists(): + self.known_hosts_file.unlink() + def test_init_no_args(self): ssh_connection = SSHConnection() @@ -48,9 +64,15 @@ def check_ssh_connection_for_default_values(self, ssh_connection): self.assertEqual(ssh_connection.port, DEFAULT_SSH_PORT) self.assertEqual(ssh_connection.username, DEFAULT_SSH_USERNAME) self.assertEqual(ssh_connection.password, DEFAULT_SSH_PASSWORD) + self.assertEqual( + ssh_connection.known_hosts_file, + Path.home() / DEFAULT_KNOWN_HOSTS_FILE, + ) def test_connect_error(self): - ssh_connection = SSHConnection() + print(self.known_hosts_file.read_text()) + + ssh_connection = SSHConnection(known_hosts_file=self.known_hosts_file) with self.assertRaises(GvmError, msg="SSH Connection failed"): ssh_connection.connect() @@ -58,18 +80,216 @@ def test_connect(self): with patch('paramiko.SSHClient') as SSHClientMock: client_mock = SSHClientMock.return_value client_mock.exec_command.return_value = ['a', 'b', 'c'] - ssh_connection = SSHConnection() + ssh_connection = SSHConnection( + known_hosts_file=self.known_hosts_file + ) ssh_connection.connect() self.assertEqual(ssh_connection._stdin, 'a') self.assertEqual(ssh_connection._stdout, 'b') self.assertEqual(ssh_connection._stderr, 'c') + ssh_connection.disconnect() + + def test_connect_unknown_host(self): + ssh_connection = SSHConnection( + hostname='0.0.0.1', known_hosts_file=self.known_hosts_file + ) + with self.assertRaises( + GvmError, + msg=( + "Could'nt establish a connection to fetch the remote " + "server key: [Errno 65] No route to host" + ), + ): + ssh_connection.connect() + + @patch('builtins.input') + def test_connect_adding_and_save_hostkey(self, input_mock): + + key_io = StringIO( + """-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW +QyNTUxOQAAACB69SvZKJh/9VgSL0G27b5xVYa8nethH3IERbi0YqJDXwAAAKhjwAdrY8AH +awAAAAtzc2gtZWQyNTUxOQAAACB69SvZKJh/9VgSL0G27b5xVYa8nethH3IERbi0YqJDXw +AAAEA9tGQi2IrprbOSbDCF+RmAHd6meNSXBUQ2ekKXm4/8xnr1K9komH/1WBIvQbbtvnFV +hryd62EfcgRFuLRiokNfAAAAI2FsZXhfZ2F5bm9yQEFsZXhzLU1hY0Jvb2stQWlyLmxvY2 +FsAQI= + -----END OPENSSH PRIVATE KEY-----""" + ) + key = paramiko.Ed25519Key.from_private_key(key_io) + key_type = key.get_name().replace('ssh-', '').upper() + hostname = '0.0.0.0' + input_mock.side_effect = ['yes', 'yes'] + ssh_connection = SSHConnection( + hostname=hostname, known_hosts_file=self.known_hosts_file + ) + ssh_connection._socket = paramiko.SSHClient() + keys = self.known_hosts_file.read_text() + self.assertEqual( + keys, + '127.0.0.1 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBOZWi' + 'fs+DoMqIa5Nr0wiVrzQNpMbUwaLzuSTN6rNrYA\n', + ) + + with self.assertLogs('gvm.connections', level='INFO') as cm: + hostkeys = paramiko.HostKeys(filename=self.known_hosts_file) + ssh_connection._ssh_authentication_input_loop( + hostkeys=hostkeys, key=key + ) + keys = self.known_hosts_file.read_text() + + self.assertEqual( + cm.output, + [ + "INFO:gvm.connections:Warning: " + f"Permanently added '{hostname}' ({key_type}) to " + "the list of known hosts." + ], + ) + self.assertEqual( + keys, + '127.0.0.1 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBOZWi' + 'fs+DoMqIa5Nr0wiVrzQNpMbUwaLzuSTN6rNrYA\n' + f'0.0.0.0 {key.get_name()} {key.get_base64()}\n', + ) + + @patch('builtins.input') + def test_connect_adding_and_dont_save_hostkey(self, input_mock): + + key_io = StringIO( + """-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW +QyNTUxOQAAACB69SvZKJh/9VgSL0G27b5xVYa8nethH3IERbi0YqJDXwAAAKhjwAdrY8AH +awAAAAtzc2gtZWQyNTUxOQAAACB69SvZKJh/9VgSL0G27b5xVYa8nethH3IERbi0YqJDXw +AAAEA9tGQi2IrprbOSbDCF+RmAHd6meNSXBUQ2ekKXm4/8xnr1K9komH/1WBIvQbbtvnFV +hryd62EfcgRFuLRiokNfAAAAI2FsZXhfZ2F5bm9yQEFsZXhzLU1hY0Jvb2stQWlyLmxvY2 +FsAQI= + -----END OPENSSH PRIVATE KEY-----""" + ) + key = paramiko.Ed25519Key.from_private_key(key_io) + key_type = key.get_name().replace('ssh-', '').upper() + hostname = '0.0.0.0' + input_mock.side_effect = ['yes', 'no'] + ssh_connection = SSHConnection( + hostname=hostname, known_hosts_file=self.known_hosts_file + ) + ssh_connection._socket = paramiko.SSHClient() + keys = self.known_hosts_file.read_text() + self.assertEqual( + keys, + '127.0.0.1 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBOZWi' + 'fs+DoMqIa5Nr0wiVrzQNpMbUwaLzuSTN6rNrYA\n', + ) + + with self.assertLogs('gvm.connections', level='INFO') as cm: + hostkeys = paramiko.HostKeys(filename=self.known_hosts_file) + ssh_connection._ssh_authentication_input_loop( + hostkeys=hostkeys, key=key + ) + keys = self.known_hosts_file.read_text() + + self.assertEqual( + cm.output, + [ + "INFO:gvm.connections:Warning: " + f"Host '{hostname}' ({key_type}) not added to " + "the list of known hosts." + ], + ) + + self.assertEqual( + keys, + '127.0.0.1 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBOZWi' + 'fs+DoMqIa5Nr0wiVrzQNpMbUwaLzuSTN6rNrYA\n', + ) + + @patch('builtins.input') + @patch('sys.stdout', new_callable=StringIO) + def test_connect_wrong_input(self, stdout_mock, input_mock): + + key_io = StringIO( + """-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW +QyNTUxOQAAACB69SvZKJh/9VgSL0G27b5xVYa8nethH3IERbi0YqJDXwAAAKhjwAdrY8AH +awAAAAtzc2gtZWQyNTUxOQAAACB69SvZKJh/9VgSL0G27b5xVYa8nethH3IERbi0YqJDXw +AAAEA9tGQi2IrprbOSbDCF+RmAHd6meNSXBUQ2ekKXm4/8xnr1K9komH/1WBIvQbbtvnFV +hryd62EfcgRFuLRiokNfAAAAI2FsZXhfZ2F5bm9yQEFsZXhzLU1hY0Jvb2stQWlyLmxvY2 +FsAQI= + -----END OPENSSH PRIVATE KEY-----""" + ) + key = paramiko.Ed25519Key.from_private_key(key_io) + hostname = '0.0.0.0' + key_type = key.get_name().replace('ssh-', '').upper() + inputs = ['asd', 'yes', 'yoo', 'no'] + input_mock.side_effect = inputs + ssh_connection = SSHConnection( + hostname=hostname, known_hosts_file=self.known_hosts_file + ) + ssh_connection._socket = paramiko.SSHClient() + + with self.assertLogs('gvm.connections', level='INFO') as cm: + hostkeys = paramiko.HostKeys(filename=self.known_hosts_file) + ssh_connection._ssh_authentication_input_loop( + hostkeys=hostkeys, key=key + ) + ret = stdout_mock.getvalue() + + self.assertEqual( + cm.output, + [ + "INFO:gvm.connections:Warning: " + f"Host '{hostname}' ({key_type}) not added to " + "the list of known hosts." + ], + ) + + self.assertEqual( + ret, + f"The authenticity of host '{hostname}' can't be established.\n" + f"{key_type} key fingerprint is " + "J6VESFdD3xSChn8y9PzWzeF+1tl892mOy2TqkMLO4ow.\n" + "Are you sure you want to continue connecting (yes/no)? " + "Please type 'yes' or 'no': " + "Do you want to add 0.0.0.0 to known_hosts (yes/no)? " + "Please type 'yes' or 'no': ", + ) + + @patch('builtins.input') + def test_user_denies_auth(self, input_mock): + + key_io = StringIO( + """-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW +QyNTUxOQAAACB69SvZKJh/9VgSL0G27b5xVYa8nethH3IERbi0YqJDXwAAAKhjwAdrY8AH +awAAAAtzc2gtZWQyNTUxOQAAACB69SvZKJh/9VgSL0G27b5xVYa8nethH3IERbi0YqJDXw +AAAEA9tGQi2IrprbOSbDCF+RmAHd6meNSXBUQ2ekKXm4/8xnr1K9komH/1WBIvQbbtvnFV +hryd62EfcgRFuLRiokNfAAAAI2FsZXhfZ2F5bm9yQEFsZXhzLU1hY0Jvb2stQWlyLmxvY2 +FsAQI= + -----END OPENSSH PRIVATE KEY-----""" + ) + key = paramiko.Ed25519Key.from_private_key(key_io) + hostname = '0.0.0.0' + input_mock.return_value = 'no' + ssh_connection = SSHConnection( + hostname=hostname, known_hosts_file=self.known_hosts_file + ) + ssh_connection._socket = paramiko.SSHClient() + + with self.assertRaises( + SystemExit, msg='User denied key. Host key verification failed.' + ): + hostkeys = paramiko.HostKeys(filename=self.known_hosts_file) + ssh_connection._ssh_authentication_input_loop( + hostkeys=hostkeys, key=key + ) def test_disconnect(self): with patch('paramiko.SSHClient') as SSHClientMock: client_mock = SSHClientMock.return_value client_mock.exec_command.return_value = ['a', 'b', 'c'] - ssh_connection = SSHConnection() + ssh_connection = SSHConnection( + known_hosts_file=self.known_hosts_file + ) ssh_connection.connect() self.assertEqual(ssh_connection._stdin, 'a') @@ -88,7 +308,7 @@ def test_disconnect(self): type(ssh_connection._socket) with self.assertRaises(AttributeError): - with self.assertLogs('foo', level='INFO') as cm: + with self.assertLogs('gvm.connections', level='INFO') as cm: # disconnect twice should not work ... ssh_connection.disconnect() self.assertEqual( @@ -108,25 +328,77 @@ def test_disconnect_os_error(self): client_mock.exec_command.return_value = ['a', 'b', 'c'] client_mock.close.side_effect = OSError - ssh_connection = SSHConnection() + ssh_connection = SSHConnection( + known_hosts_file=self.known_hosts_file + ) ssh_connection.connect() with self.assertRaises(OSError): - with self.assertLogs('foo', level='INFO') as cm: + with self.assertLogs('gvm.connections', level='INFO') as cm: ssh_connection.disconnect() self.assertEqual(cm.output, ['Connection closing error: ']) + def test_trigger_paramiko_ssh_except_in_get_remote_key(self): + with patch('paramiko.transport.Transport') as TransportMock: + client_mock = TransportMock.return_value + client_mock.start_client.side_effect = paramiko.SSHException('foo') + + ssh_connection = SSHConnection( + hostname="0.0.0.0", + ) + + with self.assertRaises( + GvmError, + msg="Couldn't fetch the remote server key: foo", + ): + ssh_connection._get_remote_host_key() + + def test_trigger_oserror_in_get_remote_key_connect(self): + with patch('socket.socket') as SocketMock: + client_mock = SocketMock.return_value + client_mock.connect.side_effect = OSError('foo') + + ssh_connection = SSHConnection( + hostname="0.0.0.0", + ) + + with self.assertRaises( + GvmError, + msg="Couldn't establish a connection to fetch the" + "remote server key: foo", + ): + ssh_connection._get_remote_host_key() + + def test_trigger_oserror_in_get_remote_key_disconnect(self): + with patch('paramiko.transport.Transport') as TransportMock: + client_mock = TransportMock.return_value + client_mock.close.side_effect = paramiko.SSHException('foo') + + ssh_connection = SSHConnection( + hostname="0.0.0.0", + ) + + with self.assertRaises( + GvmError, + msg="Couldn't close the connection to the" + "remote server key: foo", + ): + ssh_connection._get_remote_host_key() + def test_send(self): with patch('paramiko.SSHClient') as SSHClientMock: client_mock = SSHClientMock.return_value stdin = Mock() stdin.channel.send.return_value = 4 client_mock.exec_command.return_value = [stdin, None, None] - ssh_connection = SSHConnection() + ssh_connection = SSHConnection( + known_hosts_file=self.known_hosts_file + ) ssh_connection.connect() req = ssh_connection.send("blah") self.assertEqual(req, 4) + ssh_connection.disconnect() def test_send_error(self): with patch('paramiko.SSHClient') as SSHClientMock: @@ -134,13 +406,16 @@ def test_send_error(self): stdin = Mock() stdin.channel.send.return_value = None client_mock.exec_command.return_value = [stdin, None, None] - ssh_connection = SSHConnection() + ssh_connection = SSHConnection( + known_hosts_file=self.known_hosts_file + ) ssh_connection.connect() with self.assertRaises( GvmError, msg='Remote closed the connection' ): ssh_connection.send("blah") + ssh_connection.disconnect() def test_send_and_slice(self): with patch('paramiko.SSHClient') as SSHClientMock: @@ -148,7 +423,9 @@ def test_send_and_slice(self): stdin = Mock() stdin.channel.send.side_effect = [2, 2] client_mock.exec_command.return_value = [stdin, None, None] - ssh_connection = SSHConnection() + ssh_connection = SSHConnection( + known_hosts_file=self.known_hosts_file + ) ssh_connection.connect() req = ssh_connection.send("blah") @@ -157,6 +434,7 @@ def test_send_and_slice(self): stdin.channel.send.assert_called() with self.assertRaises(AssertionError): stdin.channel.send.assert_called_once() + ssh_connection.disconnect() def test_read(self): with patch('paramiko.SSHClient') as SSHClientMock: @@ -164,8 +442,11 @@ def test_read(self): stdout = Mock() stdout.channel.recv.return_value = b"foo bar baz" client_mock.exec_command.return_value = [None, stdout, None] - ssh_connection = SSHConnection() + ssh_connection = SSHConnection( + known_hosts_file=self.known_hosts_file + ) ssh_connection.connect() recved = ssh_connection._read() self.assertEqual(recved, b'foo bar baz') + ssh_connection.disconnect()