diff --git a/src/ssh/HISTORY.md b/src/ssh/HISTORY.md index 4ecef1cb840..6e51f477016 100644 --- a/src/ssh/HISTORY.md +++ b/src/ssh/HISTORY.md @@ -1,5 +1,9 @@ Release History =============== +1.1.3 +----- +* [bug fix] SSH Banners are printed before authentication. + 1.1.2 ----- * Remove dependency to cryptography (Az CLI core alredy has cryptography) diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index e396fcefcaa..234ee452b96 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -190,7 +190,7 @@ def _do_ssh_op(cmd, op_info, op_call): op_info.private_key_file + ', ' if delete_keys else "", op_info.public_key_file + ', ' if delete_keys else "", op_info.cert_file if delete_cert else "") - ssh_utils.do_cleanup(delete_keys, delete_cert, op_info.cert_file, + ssh_utils.do_cleanup(delete_keys, delete_cert, op_info.delete_credentials, op_info.cert_file, op_info.private_key_file, op_info.public_key_file) raise e diff --git a/src/ssh/azext_ssh/rdp_utils.py b/src/ssh/azext_ssh/rdp_utils.py index e5387b0a852..7e1960d05f7 100644 --- a/src/ssh/azext_ssh/rdp_utils.py +++ b/src/ssh/azext_ssh/rdp_utils.py @@ -43,8 +43,8 @@ def start_rdp_connection(ssh_info, delete_keys, delete_cert): ssh_process, print_ssh_logs = start_ssh_tunnel(ssh_info) ssh_connection_t0 = time.time() ssh_success, log_list = wait_for_ssh_connection(ssh_process, print_ssh_logs) - ssh_utils.do_cleanup(delete_keys, delete_cert, ssh_info.cert_file, ssh_info.private_key_file, - ssh_info.public_key_file) + ssh_utils.do_cleanup(delete_keys, delete_cert, ssh_info.delete_credentials, ssh_info.cert_file, + ssh_info.private_key_file, ssh_info.public_key_file) if ssh_success and ssh_process.poll() is None: call_rdp(local_port) @@ -56,8 +56,8 @@ def start_rdp_connection(ssh_info, delete_keys, delete_cert): telemetry.add_extension_event('ssh', ssh_connection_data) terminate_ssh(ssh_process, log_list, print_ssh_logs) - ssh_utils.do_cleanup(delete_keys, delete_cert, ssh_info.cert_file, ssh_info.private_key_file, - ssh_info.public_key_file) + ssh_utils.do_cleanup(delete_keys, delete_cert, ssh_info.delete_credentials, ssh_info.cert_file, + ssh_info.private_key_file, ssh_info.public_key_file) if delete_keys: # This is only true if keys were generated, so they must be in a temp folder. temp_dir = os.path.dirname(ssh_info.cert_file) diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index 56633acba62..b889ee7cbc6 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -5,10 +5,10 @@ import os import platform import subprocess -import multiprocessing as mp import time import datetime import re +import sys import colorama from knack import log @@ -25,16 +25,23 @@ def start_ssh_connection(op_info, delete_keys, delete_cert): try: - # Initialize these so that if something fails in the try block before these - # are initialized, then the finally block won't fail. - cleanup_process = None - log_file = None - connection_status = None - ssh_arg_list = [] if op_info.ssh_args: ssh_arg_list = op_info.ssh_args + # Redirecting stderr: + # 1. Read SSH logs to determine if authentication was successful so credentials can be deleted + # 2. Read SSHProxy error messages to print friendly error messages for well known errors. + # On Linux when connecting to a local user on a host with a banner, output gets messed up if stderr redirected. + # If user expects logs to be printed, do not redirect logs. In some ocasions output gets messed up. + is_local_user_on_linux = (platform.system() != 'Windows' and not delete_cert) + redirect_stderr = set(['-v', '-vv', '-vvv']).isdisjoint(ssh_arg_list) and \ + (op_info.is_arc or delete_cert or op_info.delete_credentials) and \ + not is_local_user_on_linux + + if redirect_stderr: + ssh_arg_list = ['-v'] + ssh_arg_list + env = os.environ.copy() if op_info.is_arc(): env['SSHPROXY_RELAY_INFO'] = connectivity_utils.format_relay_info_string(op_info.relay_info) @@ -42,26 +49,19 @@ def start_ssh_connection(op_info, delete_keys, delete_cert): # Get ssh client before starting the clean up process in case there is an error in getting client. command = [get_ssh_client_path('ssh', op_info.ssh_client_folder), op_info.get_host(), "-l", op_info.local_user] - if not op_info.cert_file and not op_info.private_key_file: - # In this case, even if delete_credentials is true, there is nothing to clean-up. - op_info.delete_credentials = False - - log_file, ssh_arg_list, cleanup_process = _start_cleanup(op_info.cert_file, op_info.private_key_file, - op_info.public_key_file, op_info.delete_credentials, - delete_keys, delete_cert, ssh_arg_list) command = command + op_info.build_args() + ssh_arg_list connection_duration = time.time() logger.debug("Running ssh command %s", ' '.join(command)) - # pylint: disable=subprocess-run-check try: - if set(['-v', '-vv', '-vvv']).isdisjoint(ssh_arg_list) or log_file: - connection_status = subprocess.run(command, shell=platform.system() == 'Windows', env=env, - stderr=subprocess.PIPE, encoding='utf-8') + # pylint: disable=consider-using-with + if redirect_stderr: + ssh_process = subprocess.Popen(command, stderr=subprocess.PIPE, env=env, encoding='utf-8') + _read_ssh_logs(ssh_process, op_info, delete_cert, delete_keys) else: - # Logs are sent to stderr. In that case, we shouldn't capture stderr. - connection_status = subprocess.run(command, shell=platform.system() == 'Windows', env=env) + ssh_process = subprocess.Popen(command, env=env, encoding='utf-8') + _wait_to_delete_credentials(ssh_process, op_info, delete_cert, delete_keys) except OSError as e: colorama.init() raise azclierror.BadRequestError(f"Failed to run ssh command with error: {str(e)}.", @@ -69,15 +69,15 @@ def start_ssh_connection(op_info, delete_keys, delete_cert): connection_duration = (time.time() - connection_duration) / 60 ssh_connection_data = {'Context.Default.AzureCLI.SSHConnectionDurationInMinutes': connection_duration} - if connection_status and connection_status.returncode == 0: + if ssh_process.poll() == 0: ssh_connection_data['Context.Default.AzureCLI.SSHConnectionStatus'] = "Success" telemetry.add_extension_event('ssh', ssh_connection_data) finally: # Even if something fails between the creation of the credentials and the end of the ssh connection, we - # want to make sure that all credentials are cleaned up, and that the clean up process is terminated. - _terminate_cleanup(delete_keys, delete_cert, op_info.delete_credentials, cleanup_process, op_info.cert_file, - op_info.private_key_file, op_info.public_key_file, log_file, connection_status) + # want to make sure that all credentials are cleaned up. + do_cleanup(delete_keys, delete_cert, op_info.delete_credentials, + op_info.cert_file, op_info.private_key_file, op_info.public_key_file) def write_ssh_config(config_info, delete_keys, delete_cert): @@ -94,6 +94,51 @@ def write_ssh_config(config_info, delete_keys, delete_cert): f.write('\n'.join(config_text)) +def _read_ssh_logs(ssh_sub, op_info, delete_cert, delete_keys): + log_list = [] + connection_established = False + t0 = time.time() + + next_line = ssh_sub.stderr.readline() + while next_line: + log_list.append(next_line) + if not next_line.startswith("debug1:") and \ + not next_line.startswith("debug2:") and \ + not next_line.startswith("debug3:") and \ + not next_line.startswith("Authenticated "): + sys.stderr.write(next_line) + _check_for_known_errors(next_line, delete_cert, log_list) + + if "debug1: Entering interactive session." in next_line: + connection_established = True + do_cleanup(delete_keys, delete_cert, op_info.delete_credentials, + op_info.cert_file, op_info.private_key_file, op_info.public_key_file) + + if not connection_established and \ + time.time() - t0 > const.CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS: + do_cleanup(delete_keys, delete_cert, op_info.delete_credentials, + op_info.cert_file, op_info.private_key_file, op_info.public_key_file) + + next_line = ssh_sub.stderr.readline() + + ssh_sub.wait() + + +def _wait_to_delete_credentials(ssh_sub, op_info, delete_cert, delete_keys): + # wait for 2 minutes. If the process isn't closed until then, delete credentials. + if delete_cert or op_info.delete_credentials: + t0 = time.time() + while (time.time() - t0) < const.CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS: + if ssh_sub.poll() is not None: + break + time.sleep(1) + + do_cleanup(delete_keys, delete_cert, op_info.delete_credentials, + op_info.cert_file, op_info.private_key_file, op_info.public_key_file) + + ssh_sub.wait() + + def create_ssh_keyfile(private_key_file, ssh_client_folder=None): sshkeygen_path = get_ssh_client_path("ssh-keygen", ssh_client_folder) command = [sshkeygen_path, "-f", private_key_file, "-t", "rsa", "-q", "-N", ""] @@ -162,51 +207,44 @@ def get_ssh_cert_principals(cert_file, ssh_client_folder=None): return principals -def _print_error_messages_from_ssh_log(log_file, connection_status, delete_cert): - with open(log_file, 'r', encoding='utf-8') as ssh_log: - log_text = ssh_log.read() - log_lines = log_text.splitlines() - if ("debug1: Authentication succeeded" not in log_text and - not re.search("^Authenticated to .*\n", log_text, re.MULTILINE)) \ - or (connection_status and connection_status.returncode): - for line in log_lines: - if "debug1:" not in line: - print(line) - - # This connection fails when using our generated certificates. - # Only throw error if conection fails with AAD login. - if "Permission denied (publickey)." in log_text and delete_cert: - # pylint: disable=bare-except - # pylint: disable=too-many-boolean-expressions - # Check if OpenSSH client and server versions are incompatible - try: - regex = 'OpenSSH.*_([0-9]+)\\.([0-9]+)' - local_major, local_minor = re.findall(regex, log_lines[0])[0] - remote_major, remote_minor = re.findall(regex, - file_utils.get_line_that_contains("remote software version", - log_lines))[0] - local_major = int(local_major) - local_minor = int(local_minor) - remote_major = int(remote_major) - remote_minor = int(remote_minor) - except: - ssh_log.close() - return - - if (remote_major < 7 or (remote_major == 7 and remote_minor < 8)) and \ - (local_major > 8 or (local_major == 8 and local_minor >= 8)): - logger.warning("The OpenSSH server version in the target VM %d.%d is too old. " - "Version incompatible with OpenSSH client version %d.%d. " - "Refer to https://bugzilla.mindrot.org/show_bug.cgi?id=3351 for more information.", - remote_major, remote_minor, local_major, local_minor) - - elif (local_major < 7 or (local_major == 7 and local_minor < 8)) and \ - (remote_major > 8 or (remote_major == 8 and remote_minor >= 8)): - logger.warning("The OpenSSH client version %d.%d is too old. " - "Version incompatible with OpenSSH server version %d.%d in the target VM. " - "Refer to https://bugzilla.mindrot.org/show_bug.cgi?id=3351 for more information.", - local_major, local_minor, remote_major, remote_minor) - ssh_log.close() +def _check_for_known_errors(error_message, delete_cert, log_lines): + # This connection fails when using our generated certificates. + # Only throw error if conection fails with AAD login. + if "Permission denied (publickey)." in error_message and delete_cert: + # pylint: disable=bare-except + # pylint: disable=too-many-boolean-expressions + # Check if OpenSSH client and server versions are incompatible + try: + regex = 'OpenSSH.*_([0-9]+)\\.([0-9]+)' + local_major, local_minor = re.findall(regex, log_lines[0])[0] + remote_version_line = file_utils.get_line_that_contains("remote software version", log_lines) + remote_major, remote_minor = re.findall(regex, remote_version_line)[0] + local_major = int(local_major) + local_minor = int(local_minor) + remote_major = int(remote_major) + remote_minor = int(remote_minor) + except: + return + + if (remote_major < 7 or (remote_major == 7 and remote_minor < 8)) and \ + (local_major > 8 or (local_major == 8 and local_minor >= 8)): + logger.warning("The OpenSSH server version in the target VM %d.%d is too old. " + "Version incompatible with OpenSSH client version %d.%d. " + "Refer to https://bugzilla.mindrot.org/show_bug.cgi?id=3351 for more information.", + remote_major, remote_minor, local_major, local_minor) + + elif (local_major < 7 or (local_major == 7 and local_minor < 8)) and \ + (remote_major > 8 or (remote_major == 8 and remote_minor >= 8)): + logger.warning("The OpenSSH client version %d.%d is too old. " + "Version incompatible with OpenSSH server version %d.%d in the target VM. " + "Refer to https://bugzilla.mindrot.org/show_bug.cgi?id=3351 for more information.", + local_major, local_minor, remote_major, remote_minor) + + regex = ("{\"level\":\"fatal\",\"msg\":\"sshproxy: error copying information from the connection: " + ".*\",\"time\":\".*\"}.*") + if re.search(regex, error_message): + logger.error("Please make sure SSH port is allowed using \"azcmagent config list\" in the target " + "Arc Server. Ensure SSHD is running on the target machine.\n") def get_ssh_client_path(ssh_command="ssh", ssh_client_folder=None): @@ -263,96 +301,17 @@ def get_ssh_client_path(ssh_command="ssh", ssh_client_folder=None): return ssh_path -def do_cleanup(delete_keys, delete_cert, cert_file, private_key, public_key, log_file=None, wait=False): - if log_file: - t0 = time.time() - match = False - while (time.time() - t0) < const.CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS and not match: - time.sleep(const.CLEANUP_TIME_INTERVAL_IN_SECONDS) - # pylint: disable=bare-except - # pylint: disable=anomalous-backslash-in-string - try: - with open(log_file, 'r', encoding='utf-8') as ssh_client_log: - log_text = ssh_client_log.read() - # The "debug1:..." message doesn't seems to exist in OpenSSH 3.9 - match = ("debug1: Authentication succeeded" in log_text or - re.search("^Authenticated to .*\n", log_text, re.MULTILINE)) - ssh_client_log.close() - except: - # If there is an exception, wait for a little bit and try again - time.sleep(const.CLEANUP_TIME_INTERVAL_IN_SECONDS) - - elif wait: - # if we are not checking the logs, but still want to wait for connection before deleting files - time.sleep(const.CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS) - - if delete_keys and private_key: +def do_cleanup(delete_keys, delete_cert, delete_credentials, cert_file, private_key, public_key): + if (delete_keys or delete_credentials) and private_key: file_utils.delete_file(private_key, f"Couldn't delete private key {private_key}. ", True) if delete_keys and public_key: file_utils.delete_file(public_key, f"Couldn't delete public key {public_key}. ", True) - if delete_cert and cert_file: + if (delete_cert or delete_credentials) and cert_file: file_utils.delete_file(cert_file, f"Couldn't delete certificate {cert_file}. ", True) - - -def _start_cleanup(cert_file, private_key_file, public_key_file, delete_credentials, delete_keys, - delete_cert, ssh_arg_list): - log_file = None - cleanup_process = None - if delete_keys or delete_cert or delete_credentials: - if '-E' not in ssh_arg_list and set(['-v', '-vv', '-vvv']).isdisjoint(ssh_arg_list): - # If the user either provides his own client log file (-E) or - # wants the client log messages to be printed to the console (-vvv/-vv/-v), - # we should not use the log files to check for connection success. - if cert_file: - log_dir = os.path.dirname(cert_file) - elif private_key_file: - log_dir = os.path.dirname(private_key_file) - log_file_name = 'ssh_client_log_' + str(os.getpid()) - log_file = os.path.join(log_dir, log_file_name) - ssh_arg_list = ['-E', log_file, '-v'] + ssh_arg_list - # Create a new process that will wait until the connection is established and then delete keys. - cleanup_process = mp.Process(target=do_cleanup, args=(delete_keys or delete_credentials, - delete_cert or delete_credentials, - cert_file, private_key_file, public_key_file, - log_file, True)) - cleanup_process.start() - - return log_file, ssh_arg_list, cleanup_process - - -def _terminate_cleanup(delete_keys, delete_cert, delete_credentials, cleanup_process, cert_file, - private_key_file, public_key_file, log_file, connection_status): - try: - if connection_status and connection_status.stderr: - if connection_status.returncode != 0: - # Check if stderr is a proxy error - regex = ("{\"level\":\"fatal\",\"msg\":\"sshproxy: error copying information from the connection: " - ".*\",\"time\":\".*\"}.*") - if re.search(regex, connection_status.stderr): - logger.error("Please make sure SSH port is allowed using \"azcmagent config list\" in the target " - "Arc Server. Ensure SSHD is running on the target machine.") - print(connection_status.stderr) - finally: - if delete_keys or delete_cert or delete_credentials: - if cleanup_process and cleanup_process.is_alive(): - cleanup_process.terminate() - # wait for process to terminate - t0 = time.time() - while cleanup_process.is_alive() and (time.time() - t0) < const.CLEANUP_AWAIT_TERMINATION_IN_SECONDS: - time.sleep(1) - - if log_file and os.path.isfile(log_file): - _print_error_messages_from_ssh_log(log_file, connection_status, delete_cert) - - # Make sure all files have been properly removed. - do_cleanup(delete_keys or delete_credentials, delete_cert or delete_credentials, - cert_file, private_key_file, public_key_file) - if log_file: - file_utils.delete_file(log_file, f"Couldn't delete temporary log file {log_file}. ", True) - if delete_keys: - # This is only true if keys were generated, so they must be in a temp folder. - temp_dir = os.path.dirname(cert_file) - file_utils.delete_folder(temp_dir, f"Couldn't delete temporary folder {temp_dir}", True) + if delete_keys and cert_file: + # This is only true if keys were generated, so they must be in a temp folder. + temp_dir = os.path.dirname(cert_file) + file_utils.delete_folder(temp_dir, f"Couldn't delete temporary folder {temp_dir}", True) def _issue_config_cleanup_warning(delete_cert, delete_keys, is_arc, cert_file, relay_info_path, ssh_client_folder): diff --git a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py index dc37b42e864..27c9f45e54d 100644 --- a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py +++ b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py @@ -12,43 +12,81 @@ from azext_ssh import ssh_utils from azext_ssh import ssh_info -class SSHUtilsTests(unittest.TestCase): - @mock.patch.object(ssh_utils, '_start_cleanup') - @mock.patch.object(ssh_utils, '_terminate_cleanup') + +class SSHUtilsTests(unittest.TestCase): + @mock.patch.object(ssh_utils, 'do_cleanup') + @mock.patch.object(ssh_utils, '_read_ssh_logs') @mock.patch.object(ssh_utils, 'get_ssh_client_path') - @mock.patch('subprocess.run') + @mock.patch('subprocess.Popen') @mock.patch('os.environ.copy') @mock.patch('platform.system') - def test_start_ssh_connection_compute(self, mock_system, mock_copy_env, mock_call, mock_path, mock_terminatecleanup, mock_startcleanup): + def test_start_ssh_connection_compute_aad_windows(self, mock_system, mock_copy_env, mock_call, mock_path, mock_read, mock_cleanup): - op_info = ssh_info.SSHSession("rg", "vm", "ip", None, None, False, "user", None, "port", None, ['arg1', 'arg2', 'arg3'], False, "Microsof.Compute", None, None, False) + op_info = ssh_info.SSHSession("rg", "vm", "ip", None, None, False, "user", None, "port", None, ['arg1', 'arg2', 'arg3'], False, "Microsof.Compute/virtualMachines", None, None, False) op_info.public_key_file = "pub" op_info.private_key_file = "priv" op_info.cert_file = "cert" op_info.ssh_client_folder = "client" + ssh_process = mock.Mock() + ssh_process.poll.return_value = 0 + mock_system.return_value = 'Windows' mock_call.return_value = 0 mock_path.return_value = 'ssh' + mock_call.return_value = ssh_process mock_copy_env.return_value = {'var1':'value1', 'var2':'value2', 'var3':'value3'} - mock_startcleanup.return_value = 'log', ['arg1', 'arg2', 'arg3', '-E', 'log', '-v'], 'cleanup process' - expected_command = ['ssh', 'ip', '-l', 'user', '-i', 'priv', '-o', 'CertificateFile=\"cert\"', '-p', 'port', 'arg1', 'arg2', 'arg3', '-E', 'log', '-v'] + expected_command = ['ssh', 'ip', '-l', 'user', '-i', 'priv', '-o', 'CertificateFile=\"cert\"', '-p', 'port', '-v', 'arg1', 'arg2', 'arg3'] expected_env = {'var1':'value1', 'var2':'value2', 'var3':'value3'} ssh_utils.start_ssh_connection(op_info, True, True) mock_path.assert_called_once_with('ssh', 'client') - mock_startcleanup.assert_called_with('cert', 'priv', 'pub', False, True, True, ['arg1', 'arg2', 'arg3']) - mock_call.assert_called_once_with(expected_command, shell=True, env=expected_env, stderr=mock.ANY, encoding='utf-8') - mock_terminatecleanup.assert_called_once_with(True, True, False, 'cleanup process', 'cert', 'priv', 'pub', 'log', 0) - - @mock.patch.object(ssh_utils, '_terminate_cleanup') + mock_call.assert_called_once_with(expected_command, stderr=mock.ANY, env=expected_env, encoding='utf-8') + mock_read.assert_called_once_with(ssh_process, op_info, True, True) + mock_cleanup.assert_called_once_with(True, True, False, 'cert', 'priv', 'pub') + + @mock.patch.object(ssh_utils, 'do_cleanup') + @mock.patch.object(ssh_utils, '_wait_to_delete_credentials') + @mock.patch.object(ssh_utils, 'get_ssh_client_path') + @mock.patch('subprocess.Popen') @mock.patch('os.environ.copy') + @mock.patch('platform.system') + def test_start_ssh_connection_compute_local_linux(self, mock_system, mock_copy_env, mock_call, mock_path, mock_wait, mock_cleanup): + + op_info = ssh_info.SSHSession("rg", "vm", "ip", None, None, False, "user", None, "port", None, ['arg1', 'arg2', 'arg3'], False, "Microsof.Compute", None, None, False) + op_info.public_key_file = "pub" + op_info.private_key_file = "priv" + op_info.cert_file = "cert" + op_info.ssh_client_folder = "client" + + ssh_process = mock.Mock() + ssh_process.poll.return_value = 0 + + mock_system.return_value = 'Linux' + mock_call.return_value = 0 + mock_path.return_value = 'ssh' + mock_call.return_value = ssh_process + mock_copy_env.return_value = {'var1':'value1', 'var2':'value2', 'var3':'value3'} + expected_command = ['ssh', 'ip', '-l', 'user', '-i', 'priv', '-o', 'CertificateFile=\"cert\"', '-p', 'port', 'arg1', 'arg2', 'arg3'] + expected_env = {'var1':'value1', 'var2':'value2', 'var3':'value3'} + + ssh_utils.start_ssh_connection(op_info, False, False) + + mock_path.assert_called_once_with('ssh', 'client') + mock_call.assert_called_once_with(expected_command, env=expected_env, encoding='utf-8') + mock_wait.assert_called_once_with(ssh_process, op_info, False, False) + mock_cleanup.assert_called_once_with(False, False, False, 'cert', 'priv', 'pub') + + + @mock.patch.object(ssh_utils, 'do_cleanup') + @mock.patch.object(ssh_utils, '_read_ssh_logs') @mock.patch.object(ssh_utils, 'get_ssh_client_path') - @mock.patch('subprocess.run') + @mock.patch('os.environ.copy') + @mock.patch('subprocess.Popen') @mock.patch('azext_ssh.custom.connectivity_utils.format_relay_info_string') @mock.patch('platform.system') - def test_start_ssh_connection_arc(self, mock_system, mock_relay_str, mock_call, mock_path, mock_copy_env, mock_terminatecleanup): + def test_start_ssh_connection_arc_aad_windows(self, mock_platform, mock_relay_str, mock_call, mock_copy_env, mock_path, mock_read, mock_cleanup): op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, "user", None, "port", None, ['arg1'], False, "Microsoft.HybridCompute", None, None, False) op_info.public_key_file = "pub" @@ -57,9 +95,49 @@ def test_start_ssh_connection_arc(self, mock_system, mock_relay_str, mock_call, op_info.ssh_client_folder = "client" op_info.proxy_path = "proxy" op_info.relay_info = "relay" + + ssh_process = mock.Mock() + ssh_process.poll.return_value = 0 - mock_system.return_value = 'Linux' - mock_call.return_value = 0 + mock_platform.return_value = 'Windows' + mock_call.return_value = ssh_process + mock_relay_str.return_value = 'relay_string' + mock_copy_env.return_value = {'var1':'value1', 'var2':'value2', 'var3':'value3'} + mock_path.return_value = 'ssh' + expected_command = ['ssh', 'vm', '-l', 'user', '-o', 'ProxyCommand=\"proxy\" -p port', '-i', 'priv', '-o', 'CertificateFile=\"cert\"', '-v', 'arg1'] + expected_env = {'var1':'value1', 'var2':'value2', 'var3':'value3', 'SSHPROXY_RELAY_INFO':'relay_string'} + + ssh_utils.start_ssh_connection(op_info, True, True) + + mock_relay_str.assert_called_once_with('relay') + mock_path.assert_called_once_with('ssh', 'client') + mock_call.assert_called_once_with(expected_command, stderr=mock.ANY, env=expected_env, encoding='utf-8') + mock_cleanup.assert_called_once_with(True, True, False, 'cert', 'priv', 'pub') + mock_read.assert_called_once_with(ssh_process, op_info, True, True) + + + @mock.patch.object(ssh_utils, 'do_cleanup') + @mock.patch.object(ssh_utils, '_wait_to_delete_credentials') + @mock.patch.object(ssh_utils, 'get_ssh_client_path') + @mock.patch('os.environ.copy') + @mock.patch('subprocess.Popen') + @mock.patch('azext_ssh.custom.connectivity_utils.format_relay_info_string') + @mock.patch('platform.system') + def test_start_ssh_connection_arc_local_linux(self, mock_platform, mock_relay_str, mock_call, mock_copy_env, mock_path, mock_wait, mock_cleanup): + + op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, "user", None, "port", None, ['arg1'], False, "Microsoft.HybridCompute", None, None, False) + op_info.public_key_file = "pub" + op_info.private_key_file = "priv" + op_info.cert_file = "cert" + op_info.ssh_client_folder = "client" + op_info.proxy_path = "proxy" + op_info.relay_info = "relay" + + ssh_process = mock.Mock() + ssh_process.poll.return_value = 0 + + mock_platform.return_value = 'Linux' + mock_call.return_value = ssh_process mock_relay_str.return_value = 'relay_string' mock_copy_env.return_value = {'var1':'value1', 'var2':'value2', 'var3':'value3'} mock_path.return_value = 'ssh' @@ -70,10 +148,11 @@ def test_start_ssh_connection_arc(self, mock_system, mock_relay_str, mock_call, mock_relay_str.assert_called_once_with('relay') mock_path.assert_called_once_with('ssh', 'client') - mock_call.assert_called_once_with(expected_command, shell=False, env=expected_env, stderr=mock.ANY, encoding='utf-8') - mock_terminatecleanup.assert_called_once_with(False, False, False, None, 'cert', 'priv', 'pub', None, 0) - - + mock_call.assert_called_once_with(expected_command, env=expected_env, encoding='utf-8') + mock_cleanup.assert_called_once_with(False, False, False, 'cert', 'priv', 'pub') + mock_wait.assert_called_once_with(ssh_process, op_info, False, False) + + @mock.patch.object(ssh_utils, '_issue_config_cleanup_warning') @mock.patch('os.path.abspath') def test_write_ssh_config_ip_and_vm_compute_append(self, mock_abspath, mock_warning):