Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SSH connection: Ask user before connection to unknown host/add host to known_hosts. #486

Merged
merged 13 commits into from
May 31, 2021
Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ and this project adheres to [Calendar Versioning](https://calver.org)html).
## [Unreleased]
### Added
### Changed
* For SSH Connections: Reject unknown hosts, ask user if he wants to connect to unknown remote host and ask user if he wants to add the host to `known_hosts` [#486](https://github.com/greenbone/python-gvm/pull/486)

### Deprecated
### Removed
### Fixed
Expand Down
129 changes: 127 additions & 2 deletions gvm/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,16 @@
"""
Module for connections to GVM server daemons like gvmd and ospd.
"""
import base64
import hashlib
import logging
import socket as socketlib
import ssl
import sys
import time


from pathlib import Path
from typing import Optional, Union

import paramiko
Expand All @@ -43,6 +48,7 @@
DEFAULT_SSH_USERNAME = "gmp"
DEFAULT_SSH_PASSWORD = ""
DEFAULT_HOSTNAME = '127.0.0.1'
DEFAULT_KNOWN_HOSTS_FILE = ".ssh/known_hosts"
MAX_SSH_DATA_LENGTH = 4095


Expand Down Expand Up @@ -189,6 +195,7 @@ def __init__(
port: Optional[int] = DEFAULT_SSH_PORT,
username: Optional[str] = DEFAULT_SSH_USERNAME,
password: Optional[str] = DEFAULT_SSH_PASSWORD,
known_hosts_file: Optional[str] = None,
):
super().__init__(timeout=timeout)

Expand All @@ -200,6 +207,11 @@ def __init__(
self.password = (
password if password is not None else DEFAULT_SSH_PASSWORD
)
self.known_hosts_file = (
Path(known_hosts_file)
if known_hosts_file is not None
else Path.home() / DEFAULT_KNOWN_HOSTS_FILE
)

def _send_all(self, data) -> int:
"""Returns the sum of sent bytes if success"""
Expand All @@ -216,12 +228,124 @@ def _send_all(self, data) -> int:
data = data[sent:]
return sent_sum

def _ssh_authentication_input_loop(
self, hostkeys: paramiko.HostKeys, key: paramiko.PKey
) -> None:
# Ask user for permission to continue
# let it look like openssh
sha64_fingerprint = base64.b64encode(
hashlib.sha256(base64.b64decode(key.get_base64())).digest()
).decode("utf-8")[:-1]
key_type = key.get_name().replace('ssh-', '').upper()
print(
f"The authenticity of host '{self.hostname}' can't "
"be established."
)
print(f"{key_type} key fingerprint " f"is {sha64_fingerprint}.")
bjoernricks marked this conversation as resolved.
Show resolved Hide resolved
print('Are you sure you want to continue connecting (yes/no)? ', end='')
add = input()
while True:
if add == 'yes':
hostkeys.add(self.hostname, key.get_name(), key)
# ask user if the key should be added permanently
print(
f'Do you want to add {self.hostname} '
'to known_hosts (yes/no)? ',
end='',
)
save = input()
while True:
if save == 'yes':
try:
hostkeys.save(filename=self.known_hosts_file)
except OSError as e:
raise GvmError(
'Something went wrong with writing '
f'the known_hosts file: {e}'
) from None
logger.info(
"Warning: Permanently added '%s' (%s) to "
"the list of known hosts.",
self.hostname,
key_type,
)
break
elif save == 'no':
logger.info(
"Warning: Host '%s' (%s) not added to "
"the list of known hosts.",
self.hostname,
key_type,
)
break
else:
print("Please type 'yes' or 'no': ", end='')
save = input()
break
elif add == 'no':
return sys.exit(
'User denied key. Host key verification failed.'
)
else:
print("Please type 'yes' or 'no': ", end='')
add = input()

def _get_remote_host_key(self):
"""Get the remote host key for ssh connection"""
try:
tmp_socket = socketlib.socket()
tmp_socket.connect((self.hostname, 22))
except OSError as e:
raise GvmError(
"Couldn't establish a connection to fetch the"
f"remote server key: {e}"
bjoernricks marked this conversation as resolved.
Show resolved Hide resolved
) from None

trans = paramiko.transport.Transport(tmp_socket)
try:
trans.start_client()
except paramiko.SSHException as e:
raise GvmError(
"Couldn't fetch the" f"remote server key: {e}"
bjoernricks marked this conversation as resolved.
Show resolved Hide resolved
) from None
key = trans.get_remote_server_key()
try:
trans.close()
except paramiko.SSHException as e:
raise GvmError(
"Couldn't close the connection to the" f"remote server key: {e}"
bjoernricks marked this conversation as resolved.
Show resolved Hide resolved
) from None
return key

def _ssh_authentication(self) -> None:
"""Search/add/save the servers key for the SSH authentication process"""

# set to reject policy (avoid MITM attacks)
self._socket.set_missing_host_key_policy(paramiko.RejectPolicy())
# openssh is posix, so this might only a posix approach
# https://stackoverflow.com/q/32945533
try:
# load the keys into paramiko and check if remote is in the list
self._socket.load_host_keys(filename=self.known_hosts_file)
except OSError as e:
raise GvmError(
'Something went wrong with reading '
f'the known_hosts file: {e}'
) from None
hostkeys = self._socket.get_host_keys()
if not hostkeys.lookup(self.hostname):
# Key not found, so connect to remote and fetch the key
# with the paramiko Transport protocol
key = self._get_remote_host_key()

self._ssh_authentication_input_loop(hostkeys=hostkeys, key=key)

def connect(self) -> None:
"""
Connect to the SSH server and authenticate to it
"""
self._socket = paramiko.SSHClient()
self._socket.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self._ssh_authentication()

try:
self._socket.connect(
Expand All @@ -242,8 +366,9 @@ def connect(self) -> None:
paramiko.BadHostKeyException,
paramiko.AuthenticationException,
paramiko.SSHException,
ConnectionError,
) as e:
raise GvmError("SSH Connection failed", e) from None
raise GvmError(f"SSH Connection failed: {e}") from None

def _read(self) -> bytes:
return self._stdout.channel.recv(BUF_SIZE)
Expand Down
Loading