From a019617cb42f76a8ee657f9bc5343b2215c84b63 Mon Sep 17 00:00:00 2001 From: kyleknap Date: Tue, 25 Oct 2022 14:06:07 -0700 Subject: [PATCH 1/4] Add support for configuration of SSO sessions This commit includes three new features: * Add support for configure SSO sessions as part of the `configure sso` command. When running this command, users will be prompted for a SSO session to start in which they can create a new SSO session or reuse an existing SSO session when configuring a profile. * Add `configure sso-session` command. This allows users to configure a SSO session without having to reconfigure a profile. * Add `--sso-session` argument to `sso login` command. This allows users to refresh the SSO access token explicitly through a configured SSO session instead of having to specify a profile that is configured to use the SSO session in order to refresh the token. --- awscli/customizations/configure/__init__.py | 10 +- awscli/customizations/configure/configure.py | 2 + awscli/customizations/configure/sso.py | 581 +++++++++++++++---- awscli/customizations/sso/login.py | 18 +- awscli/customizations/sso/utils.py | 46 +- 5 files changed, 503 insertions(+), 154 deletions(-) diff --git a/awscli/customizations/configure/__init__.py b/awscli/customizations/configure/__init__.py index 6055529251d4..ab0630533545 100644 --- a/awscli/customizations/configure/__init__.py +++ b/awscli/customizations/configure/__init__.py @@ -46,6 +46,10 @@ def profile_to_section(profile_name): """Converts a profile name to a section header to be used in the config.""" if profile_name == 'default': return profile_name - if any(c in _WHITESPACE for c in profile_name): - profile_name = shlex_quote(profile_name) - return 'profile %s' % profile_name + return get_section_header('profile', profile_name) + + +def get_section_header(section_type, section_name): + if any(c in _WHITESPACE for c in section_name): + section_name = shlex_quote(section_name) + return f'{section_type} {section_name}' diff --git a/awscli/customizations/configure/configure.py b/awscli/customizations/configure/configure.py index 51d3b82fc037..e1365d1b9c75 100644 --- a/awscli/customizations/configure/configure.py +++ b/awscli/customizations/configure/configure.py @@ -25,6 +25,7 @@ from awscli.customizations.configure.importer import ConfigureImportCommand from awscli.customizations.configure.listprofiles import ListProfilesCommand from awscli.customizations.configure.sso import ConfigureSSOCommand +from awscli.customizations.configure.sso import ConfigureSSOSessionCommand from awscli.customizations.configure.exportcreds import \ ConfigureExportCredentialsCommand @@ -82,6 +83,7 @@ class ConfigureCommand(BasicCommand): {'name': 'import', 'command_class': ConfigureImportCommand}, {'name': 'list-profiles', 'command_class': ListProfilesCommand}, {'name': 'sso', 'command_class': ConfigureSSOCommand}, + {'name': 'sso-session', 'command_class': ConfigureSSOSessionCommand}, {'name': 'export-credentials', 'command_class': ConfigureExportCredentialsCommand}, ] diff --git a/awscli/customizations/configure/sso.py b/awscli/customizations/configure/sso.py index 7c9c73231304..19442634c53d 100644 --- a/awscli/customizations/configure/sso.py +++ b/awscli/customizations/configure/sso.py @@ -10,9 +10,14 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +import collections +import itertools +import json import os import logging +import re +import colorama from botocore import UNSIGNED from botocore.config import Config from botocore.configprovider import ConstantProvider @@ -20,39 +25,90 @@ from botocore.utils import is_valid_endpoint_url from prompt_toolkit import prompt as ptk_prompt +from prompt_toolkit.application import get_app from prompt_toolkit.completion import WordCompleter +from prompt_toolkit.formatted_text import FormattedText +from prompt_toolkit.styles import Style from prompt_toolkit.validation import Validator from prompt_toolkit.validation import ValidationError from awscli.customizations.utils import uni_print -from awscli.customizations.commands import BasicCommand -from awscli.customizations.configure import profile_to_section +from awscli.customizations.configure import ( + profile_to_section, get_section_header, +) from awscli.customizations.configure.writer import ConfigFileWriter from awscli.customizations.wizard.ui.selectmenu import select_menu from awscli.customizations.sso.utils import ( - do_sso_login, PrintOnlyHandler, LOGIN_ARGS, + do_sso_login, parse_sso_registration_scopes, PrintOnlyHandler, LOGIN_ARGS, + BaseSSOCommand, ) from awscli.formatter import CLI_OUTPUT_FORMATS logger = logging.getLogger(__name__) +_CMD_PROMPT_USAGE = ( + 'To keep an existing value, hit enter when prompted for the value. When ' + 'you are prompted for information, the current value will be displayed in ' + '[brackets]. If the config item has no value, it is displayed as ' + '[None] or omitted entirely.\n\n' +) +_CONFIG_EXTRA_INFO = ( + 'Note: The configuration is saved in the shared configuration file. ' + 'By default, ``~/.aws/config``. For more information, see the ' + '"Configuring the AWS CLI to use AWS Single Sign-On" section in the AWS ' + 'CLI User Guide:' + '\n\nhttps://docs.aws.amazon.com/cli/latest/userguide/cli-configure-sso.html' +) + -class StartUrlValidator(Validator): +class ValidatorWithDefault(Validator): def __init__(self, default=None): - super(StartUrlValidator, self).__init__() + super(ValidatorWithDefault, self).__init__() self._default = default + def _raise_validation_error(self, document, message): + index = len(document.text) + raise ValidationError(index, message) + + +class StartUrlValidator(ValidatorWithDefault): def validate(self, document): # If there's a default, allow an empty prompt if not document.text and self._default: return if not is_valid_endpoint_url(document.text): - index = len(document.text) - raise ValidationError(index, 'Not a valid Start URL') + self._raise_validation_error(document, 'Not a valid Start URL') + + +class RequiredInputValidator(ValidatorWithDefault): + def validate(self, document): + if document.text or self._default: + return + self._raise_validation_error(document, 'A value is required') + + +class ScopesValidator(ValidatorWithDefault): + def validate(self, document): + # If there's a default, allow an empty prompt + if not document.text and self._default: + return + if not self._is_comma_separated_list(document.text): + self._raise_validation_error( + document, 'Scope values must be separated by commas') + + def _is_comma_separated_list(self, value): + value.strip() + scopes = value.split(',') + for scope in scopes: + if re.findall(r'\s', scope.strip()): + return False + return True class PTKPrompt(object): + _DEFAULT_PROMPT_FORMAT = '{prompt_text} [{current_value}]: ' + def __init__(self, prompter=None): if prompter is None: prompter = ptk_prompt @@ -61,29 +117,35 @@ def __init__(self, prompter=None): def _create_completer(self, completions): if completions is None: completions = [] + completer_kwargs = { + 'words': completions, + 'pattern': re.compile(r'\S+') + } if isinstance(completions, dict): - meta_dict = completions - completions = list(meta_dict.keys()) - completer = WordCompleter( - completions, - sentence=True, - meta_dict=meta_dict, - ) - else: - completer = WordCompleter(completions, sentence=True) - return completer + completer_kwargs['meta_dict'] = completions + completer_kwargs['words'] = list(completions.keys()) + return WordCompleter(**completer_kwargs) def get_value(self, current_value, prompt_text='', - completions=None, validator=None): - completer = self._create_completer(completions) - prompt_string = u'{} [{}]: '.format(prompt_text, current_value) - response = self._prompter( - prompt_string, - validator=validator, - validate_while_typing=False, - completer=completer, - complete_while_typing=True, + completions=None, validator=None, toolbar=None, + prompt_fmt=None): + if prompt_fmt is None: + prompt_fmt = self._DEFAULT_PROMPT_FORMAT + prompt_string = prompt_fmt.format( + prompt_text=prompt_text, + current_value=current_value ) + prompter_kwargs = { + 'validator': validator, + 'validate_while_typing': False, + 'completer': self._create_completer(completions), + 'complete_while_typing': True, + 'style': self._get_prompt_style(), + } + if toolbar: + prompter_kwargs['bottom_toolbar'] = toolbar + prompter_kwargs['refresh_interval'] = 0.2 + response = self._prompter(prompt_string, **prompter_kwargs) # Strip any extra white space response = response.strip() if not response: @@ -91,6 +153,13 @@ def get_value(self, current_value, prompt_text='', response = current_value return response + def _get_prompt_style(self): + return Style.from_dict( + { + 'bottom-toolbar': 'noreverse', + } + ) + def display_account(account): """Converts an SSO account response into a display string. @@ -111,67 +180,219 @@ def display_account(account): return account_template.format(**account) -class ConfigureSSOCommand(BasicCommand): +class SSOSessionConfigurationPrompter: + _DEFAULT_SSO_SCOPE = 'sso:account:access' + _KNOWN_SSO_SCOPES = { + 'sso:account:access': ( + 'Grants access to AWS IAM Identity Center accounts and permission ' + 'sets' + ) + } + + def __init__(self, botocore_session, prompter): + self._botocore_session = botocore_session + self._prompter = prompter + self._sso_sessions = self._botocore_session.full_config.get( + 'sso_sessions', {}) + self._sso_session = None + self.sso_session_config = {} + + @property + def sso_session(self): + return self._sso_session + + @sso_session.setter + def sso_session(self, value): + self._sso_session = value + self.sso_session_config = self._sso_sessions.get( + self._sso_session, {}).copy() + + def prompt_for_sso_session(self, required=True): + prompt_text = 'SSO session name' + prompt_fmt = None + validator_cls = None + if required: + validator_cls = RequiredInputValidator + if not self.sso_session: + prompt_fmt = f'{prompt_text}: ' + if not required: + prompt_fmt = f'{prompt_text} (Recommended): ' + sso_session = self._prompt_for( + 'sso_session', prompt_text, + completions=sorted(self._sso_sessions), + toolbar=self._get_sso_session_toolbar, + validator_cls=validator_cls, + prompt_fmt=prompt_fmt, + current_value=self.sso_session, + ) + self.sso_session = sso_session + return sso_session + + def prompt_for_sso_start_url(self): + return self._prompt_for( + 'sso_start_url', 'SSO start URL', + completions=self._get_potential_start_urls(), + validator_cls=StartUrlValidator, + ) + + def prompt_for_sso_region(self): + return self._prompt_for( + 'sso_region', 'SSO region', + completions=self._get_potential_sso_regions(), + validator_cls=RequiredInputValidator, + ) + + def prompt_for_sso_registration_scopes(self): + if 'sso_registration_scopes' not in self.sso_session_config: + self.sso_session_config['sso_registration_scopes'] = \ + self._DEFAULT_SSO_SCOPE + raw_scopes = self._prompt_for( + 'sso_registration_scopes', 'SSO registration scopes', + completions=self._get_potential_sso_registrations_scopes(), + validator_cls=ScopesValidator, + ) + return parse_sso_registration_scopes(raw_scopes) + + def _prompt_for(self, config_name, text, + completions=None, validator_cls=None, + toolbar=None, prompt_fmt=None, current_value=None): + if current_value is None: + current_value = self.sso_session_config.get(config_name) + validator = None + if validator_cls: + validator = validator_cls(current_value) + value = self._prompter.get_value( + current_value, text, + completions=completions, + validator=validator, + toolbar=toolbar, + prompt_fmt=prompt_fmt + ) + if value: + self.sso_session_config[config_name] = value + return value + + def _get_sso_session_toolbar(self): + current_input = get_app().current_buffer.document.text + if current_input in self._sso_sessions: + selected_sso_config = self._sso_sessions[current_input] + return FormattedText([ + ('', self._get_toolbar_border()), + ('', '\n'), + ('bold', f'Configuration for SSO session: {current_input}\n\n'), + ('', json.dumps(selected_sso_config, indent=2)), + ]) + + def _get_toolbar_border(self): + horizontal_line_char = '\u2500' + return horizontal_line_char * get_app().output.get_size().columns + + def _get_potential_start_urls(self): + profiles = self._botocore_session.full_config.get('profiles', {}) + configs_to_search = itertools.chain( + profiles.values(), + self._sso_sessions.values() + ) + potential_start_urls = set() + for config_to_search in configs_to_search: + if 'sso_start_url' in config_to_search: + start_url = config_to_search['sso_start_url'] + potential_start_urls.add(start_url) + return list(potential_start_urls) + + def _get_potential_sso_regions(self): + return self._botocore_session.get_available_regions('sso-oidc') + + def _get_potential_sso_registrations_scopes(self): + potential_scopes = self._KNOWN_SSO_SCOPES.copy() + scopes_to_sessions = self._get_previously_used_scopes_to_sso_sessions() + for scope, sso_sessions in scopes_to_sessions.items(): + if scope not in potential_scopes: + potential_scopes[scope] = ( + f'Used in SSO sessions: {", ".join(sso_sessions)}' + ) + return potential_scopes + + def _get_previously_used_scopes_to_sso_sessions(self): + scopes_to_sessions = collections.defaultdict(list) + for sso_session, sso_session_config in self._sso_sessions.items(): + if 'sso_registration_scopes' in sso_session_config: + parsed_scopes = parse_sso_registration_scopes( + sso_session_config['sso_registration_scopes'] + ) + for parsed_scope in parsed_scopes: + scopes_to_sessions[parsed_scope].append(sso_session) + return scopes_to_sessions + + +class BaseSSOConfigurationCommand(BaseSSOCommand): + def __init__(self, session, prompter=None, config_writer=None): + super(BaseSSOConfigurationCommand, self).__init__(session) + if prompter is None: + prompter = PTKPrompt() + self._prompter = prompter + if config_writer is None: + config_writer = ConfigFileWriter() + self._config_writer = config_writer + self._sso_sessions = self._session.full_config.get('sso_sessions', {}) + self._sso_session_prompter = SSOSessionConfigurationPrompter( + botocore_session=session, prompter=self._prompter, + ) + + def _write_sso_configuration(self): + self._update_section( + section_header=get_section_header( + 'sso-session', self._sso_session_prompter.sso_session), + new_values=self._sso_session_prompter.sso_session_config + ) + + def _update_section(self, section_header, new_values): + config_path = self._session.get_config_variable('config_file') + config_path = os.path.expanduser(config_path) + new_values['__section__'] = section_header + self._config_writer.update_config(new_values, config_path) + + +class ConfigureSSOCommand(BaseSSOConfigurationCommand): NAME = 'sso' SYNOPSIS = ('aws configure sso [--profile profile-name]') DESCRIPTION = ( 'The ``aws configure sso`` command interactively prompts for the ' 'configuration values required to create a profile that sources ' - 'temporary AWS credentials from AWS Single Sign-On. To keep an ' - 'existing value, hit enter when prompted for the value. When you ' - 'are prompted for information, the current value will be displayed in ' - '[brackets]. If the config item has no value, it is displayed as ' - '[None]. When providing the ``--profile`` parameter the named profile ' + 'temporary AWS credentials from AWS Single Sign-On.\n\n' + f'{_CMD_PROMPT_USAGE}' + 'When providing the ``--profile`` parameter the named profile ' 'will be created or updated. When a profile is not explicitly set ' - 'the profile name will be prompted for.' - '\n\nNote: The configuration is saved in the shared configuration ' - 'file. By default, ``~/.aws/config``.' - 'For more information, see the "Configuring the AWS CLI to use AWS ' - 'Single Sign-On" section in the AWS CLI User Guide:' - '\n\nhttps://docs.aws.amazon.com/cli/latest/userguide/cli-configure-sso.html' + 'the profile name will be prompted for.\n\n' + f'{_CONFIG_EXTRA_INFO}' ) # TODO: Add CLI parameters to skip prompted values, --start-url, etc. ARG_TABLE = LOGIN_ARGS def __init__(self, session, prompter=None, selector=None, config_writer=None, sso_token_cache=None, sso_login=None): - super(ConfigureSSOCommand, self).__init__(session) - if prompter is None: - prompter = PTKPrompt() - self._prompter = prompter + super(ConfigureSSOCommand, self).__init__( + session, prompter=prompter, config_writer=config_writer) if selector is None: selector = select_menu self._selector = selector - if config_writer is None: - config_writer = ConfigFileWriter() if sso_login is None: sso_login = do_sso_login self._sso_login = sso_login - self._config_writer = config_writer self._sso_token_cache = sso_token_cache - self._new_values = {} + self._new_profile_config_values = {} self._original_profile_name = self._session.profile try: - self._config = self._session.get_scoped_config() + self._profile_config = self._session.get_scoped_config() except ProfileNotFound: - self._config = {} + self._profile_config = {} + self._set_sso_session_if_configured_in_profile() - def _prompt_for(self, config_name, text, - completions=None, validator_cls=None): - current_value = self._config.get(config_name) - if validator_cls is None: - validator = None - else: - validator = validator_cls(current_value) - new_value = self._prompter.get_value( - current_value, text, - completions=completions, - validator=validator, - ) - if new_value: - self._new_values[config_name] = new_value - return new_value + def _set_sso_session_if_configured_in_profile(self): + if 'sso_session' in self._profile_config: + self._sso_session_prompter.sso_session = \ + self._profile_config['sso_session'] def _handle_single_account(self, accounts): sso_account_id = accounts[0]['accountId'] @@ -186,7 +407,8 @@ def _handle_multiple_accounts(self, accounts): 'There are {} AWS accounts available to you.\n' ) uni_print(available_accounts_msg.format(len(accounts))) - selected_account = self._selector(accounts, display_account) + selected_account = self._selector( + accounts, display_format=display_account) sso_account_id = selected_account['accountId'] return sso_account_id @@ -204,7 +426,7 @@ def _prompt_for_account(self, sso, sso_token): else: sso_account_id = self._handle_multiple_accounts(accounts) uni_print('Using the account ID {}\n'.format(sso_account_id)) - self._new_values['sso_account_id'] = sso_account_id + self._new_profile_config_values['sso_account_id'] = sso_account_id return sso_account_id def _handle_single_role(self, roles): @@ -238,57 +460,43 @@ def _prompt_for_role(self, sso, sso_token, sso_account_id): else: sso_role_name = self._handle_multiple_roles(roles) uni_print('Using the role name "{}"\n'.format(sso_role_name)) - self._new_values['sso_role_name'] = sso_role_name + self._new_profile_config_values['sso_role_name'] = sso_role_name return sso_role_name - def _prompt_for_profile(self, sso_account_id, sso_role_name): + def _prompt_for_profile(self, sso_account_id=None, sso_role_name=None): if self._original_profile_name: profile_name = self._original_profile_name else: - default_profile = '{}-{}'.format(sso_role_name, sso_account_id) text = 'CLI profile name' - profile_name = self._prompter.get_value(default_profile, text) + default_profile = None + if sso_account_id and sso_role_name: + default_profile = '{}-{}'.format(sso_role_name, sso_account_id) + validator = RequiredInputValidator(default_profile) + profile_name = self._prompter.get_value( + default_profile, text, validator=validator) return profile_name - def _get_potential_start_urls(self): - profiles = self._session.full_config.get('profiles', []) - potential_start_urls = set() - for profile, config in profiles.items(): - if 'sso_start_url' in config: - start_url = config['sso_start_url'] - potential_start_urls.add(start_url) - return list(potential_start_urls) - - def _prompt_for_start_url(self): - potential_start_urls = self._get_potential_start_urls() - start_url = self._prompt_for( - 'sso_start_url', 'SSO start URL', - completions=potential_start_urls, - validator_cls=StartUrlValidator, - ) - return start_url - - def _get_potential_sso_regions(self): - return self._session.get_available_regions('sso-oidc') - - def _prompt_for_sso_region(self): - potential_sso_regions = self._get_potential_sso_regions() - sso_region = self._prompt_for( - 'sso_region', 'SSO Region', - completions=potential_sso_regions, - ) - return sso_region - def _prompt_for_cli_default_region(self): # TODO: figure out a way to get a list of reasonable client regions - return self._prompt_for('region', 'CLI default client Region') + return self._prompt_for_profile_config( + 'region', 'CLI default client Region') def _prompt_for_cli_output_format(self): - return self._prompt_for( + return self._prompt_for_profile_config( 'output', 'CLI default output format', completions=list(CLI_OUTPUT_FORMATS.keys()), ) + def _prompt_for_profile_config(self, config_name, text, completions=None): + current_value = self._profile_config.get(config_name) + new_value = self._prompter.get_value( + current_value, text, + completions=completions, + ) + if new_value: + self._new_profile_config_values[config_name] = new_value + return new_value + def _unset_session_profile(self): # The profile provided to the CLI as --profile may not exist. # This means we cannot use the session as is to create clients. @@ -302,28 +510,28 @@ def _unset_session_profile(self): def _run_main(self, parsed_args, parsed_globals): self._unset_session_profile() - start_url = self._prompt_for_start_url() - sso_region = self._prompt_for_sso_region() on_pending_authorization = None if parsed_args.no_browser: on_pending_authorization = PrintOnlyHandler() + sso_registration_args = self._prompt_for_sso_registration_args() sso_token = self._sso_login( self._session, - sso_region, - start_url, token_cache=self._sso_token_cache, on_pending_authorization=on_pending_authorization, + **sso_registration_args ) # Construct an SSO client to explore the accounts / roles client_config = Config( signature_version=UNSIGNED, - region_name=sso_region, + region_name=sso_registration_args['sso_region'], ) sso = self._session.create_client('sso', config=client_config) - sso_account_id = self._prompt_for_account(sso, sso_token) - sso_role_name = self._prompt_for_role(sso, sso_token, sso_account_id) + sso_account_id, sso_role_name = self._prompt_for_sso_account_and_role( + sso, sso_token + ) + configured_for_aws_credentials = all((sso_account_id, sso_role_name)) # General CLI configuration self._prompt_for_cli_default_region() @@ -331,20 +539,153 @@ def _run_main(self, parsed_args, parsed_globals): profile_name = self._prompt_for_profile(sso_account_id, sso_role_name) - usage_msg = ( - '\nTo use this profile, specify the profile name using ' - '--profile, as shown:\n\n' - 'aws s3 ls --profile {}\n' - ) - uni_print(usage_msg.format(profile_name)) - self._write_new_config(profile_name) + self._print_conclusion(configured_for_aws_credentials, profile_name) return 0 + def _prompt_for_sso_registration_args(self): + sso_session = self._sso_session_prompter.prompt_for_sso_session( + required=False) + if sso_session is None: + self._warn_configuring_using_legacy_format() + return self._prompt_for_registration_args_with_legacy_format() + else: + self._set_sso_session_in_profile_config(sso_session) + if sso_session in self._sso_sessions: + return self._get_sso_registration_args_from_sso_config( + sso_session) + else: + return self._prompt_for_registration_args_for_new_sso_session( + sso_session=sso_session + ) + + def _prompt_for_registration_args_with_legacy_format(self): + self._store_sso_session_prompter_answers_to_profile_config() + self._set_sso_session_defaults_from_profile_config() + start_url, sso_region = self._prompt_for_sso_start_url_and_sso_region() + return { + 'start_url': start_url, + 'sso_region': sso_region + } + + def _get_sso_registration_args_from_sso_config(self, sso_session): + sso_config = self._get_sso_session_config(sso_session) + return { + 'session_name': sso_session, + 'start_url': sso_config['sso_start_url'], + 'sso_region': sso_config['sso_region'], + 'registration_scopes': sso_config.get('registration_scopes') + } + + def _prompt_for_registration_args_for_new_sso_session(self, sso_session): + self._set_sso_session_defaults_from_profile_config() + start_url, sso_region = self._prompt_for_sso_start_url_and_sso_region() + scopes = self._sso_session_prompter.prompt_for_sso_registration_scopes() + return { + 'session_name': sso_session, + 'start_url': start_url, + 'sso_region': sso_region, + 'registration_scopes': scopes, + # We force refresh for any new SSO sessions to ensure we are not + # using any cached tokens from any previous of attempts to + # create/authenticate a new SSO session as part of the configure + # sso flow. + 'force_refresh': True + } + + def _store_sso_session_prompter_answers_to_profile_config(self): + # Wire the SSO session prompter to set config values to the + # dictionary used for writing to the profile section + self._sso_session_prompter.sso_session_config = \ + self._new_profile_config_values + + def _set_sso_session_in_profile_config(self, sso_session): + self._new_profile_config_values['sso_session'] = sso_session + + def _set_sso_session_defaults_from_profile_config(self): + # This is to ensure the SSO session prompter pulls in existing + # SSO configuration as part of the prompt if a profile was explicitly + # provided that already had SSO configuration + if 'sso_start_url' in self._profile_config: + self._sso_session_prompter.sso_session_config['sso_start_url'] = \ + self._profile_config['sso_start_url'] + if 'sso_region' in self._profile_config: + self._sso_session_prompter.sso_session_config['sso_region'] = \ + self._profile_config['sso_region'] + + def _prompt_for_sso_start_url_and_sso_region(self): + start_url = self._sso_session_prompter.prompt_for_sso_start_url() + sso_region = self._sso_session_prompter.prompt_for_sso_region() + return start_url, sso_region + + def _warn_configuring_using_legacy_format(self): + uni_print( + f'{colorama.Style.BRIGHT}WARNING: Configuring using legacy format ' + f'(e.g. without an SSO session).\n' + f'Consider re-running "configure sso" command and providing ' + f'a session name.\n{colorama.Style.RESET_ALL}' + ) + + def _prompt_for_sso_account_and_role(self, sso, sso_token): + sso_account_id = None + sso_role_name = None + try: + sso_account_id = self._prompt_for_account(sso, sso_token) + sso_role_name = self._prompt_for_role( + sso, sso_token, sso_account_id) + except sso.exceptions.UnauthorizedException as e: + uni_print( + 'Unable to list AWS accounts and/or roles. ' + 'Skipping configuring AWS credential provider for profile.\n' + ) + return sso_account_id, sso_role_name + def _write_new_config(self, profile): - config_path = self._session.get_config_variable('config_file') - config_path = os.path.expanduser(config_path) - if self._new_values: - section = profile_to_section(profile) - self._new_values['__section__'] = section - self._config_writer.update_config(self._new_values, config_path) + if self._new_profile_config_values: + profile_section = profile_to_section(profile) + self._update_section( + profile_section, self._new_profile_config_values) + if self._sso_session_prompter.sso_session: + self._write_sso_configuration() + + def _print_conclusion(self, configured_for_aws_credentials, profile_name): + if configured_for_aws_credentials: + msg = ( + '\nTo use this profile, specify the profile name using ' + '--profile, as shown:\n\n' + 'aws s3 ls --profile {}\n' + ) + else: + msg = 'Successfully configured SSO for profile: {}\n' + uni_print(msg.format(profile_name)) + + +class ConfigureSSOSessionCommand(BaseSSOConfigurationCommand): + NAME = 'sso-session' + SYNOPSIS = ('aws configure sso-session') + DESCRIPTION = ( + 'The ``aws configure sso-session`` command interactively prompts for ' + 'the configuration values required to create a SSO session. ' + 'The SSO session can then be associated to a profile to retrieve ' + 'SSO access tokens and AWS credentials.\n\n' + f'{_CMD_PROMPT_USAGE}' + f'{_CONFIG_EXTRA_INFO}' + ) + + def _run_main(self, parsed_args, parsed_globals): + self._sso_session_prompter.prompt_for_sso_session() + self._sso_session_prompter.prompt_for_sso_start_url() + self._sso_session_prompter.prompt_for_sso_region() + self._sso_session_prompter.prompt_for_sso_registration_scopes() + self._write_sso_configuration() + self._print_configuration_success() + return 0 + + def _print_configuration_success(self): + sso_session = self._sso_session_prompter.sso_session + uni_print( + f'\nCompleted configuring SSO session: {sso_session}\n' + f'Run the following to login and refresh access token for ' + f'this session:\n\n' + f'aws sso login --sso-session {sso_session}\n' + ) diff --git a/awscli/customizations/sso/login.py b/awscli/customizations/sso/login.py index c921d06a5373..bf7f5489726d 100644 --- a/awscli/customizations/sso/login.py +++ b/awscli/customizations/sso/login.py @@ -23,14 +23,24 @@ class LoginCommand(BaseSSOCommand): 'credentials. To login, the requested profile must have first been ' 'setup using ``aws configure sso``. Each time the ``login`` command ' 'is called, a new SSO access token will be retrieved. Please note ' - 'that only one login session can be active for a given SSO Start URL ' + 'that only one login session can be active for a given SSO Session ' 'and creating multiple profiles does not allow for multiple users to ' - 'be authenticated against the same SSO Start URL.' + 'be authenticated against the same SSO Session.' ) - ARG_TABLE = LOGIN_ARGS + ARG_TABLE = LOGIN_ARGS + [ + { + 'name': 'sso-session', + 'help_text': ( + 'An explicit SSO session to use to login. By default, this ' + 'command will login using the SSO session configured as part ' + 'of the requested profile and generally does not require this ' + 'argument to be set.' + ) + } + ] def _run_main(self, parsed_args, parsed_globals): - sso_config = self._get_sso_config() + sso_config = self._get_sso_config(sso_session=parsed_args.sso_session) on_pending_authorization = None if parsed_args.no_browser: on_pending_authorization = PrintOnlyHandler() diff --git a/awscli/customizations/sso/utils.py b/awscli/customizations/sso/utils.py index ae809d8031a5..775379c04780 100644 --- a/awscli/customizations/sso/utils.py +++ b/awscli/customizations/sso/utils.py @@ -77,6 +77,15 @@ def do_sso_login(session, sso_region, start_url, token_cache=None, ) +def parse_sso_registration_scopes(raw_scopes): + parsed_scopes = [] + for scope in raw_scopes.split(','): + scope = scope.strip() + if scope: + parsed_scopes.append(scope) + return parsed_scopes + + def open_browser_with_original_ld_path(url): with original_ld_library_path(): webbrowser.open_new_tab(url) @@ -148,25 +157,16 @@ class BaseSSOCommand(BasicCommand): 'sso_region', ] - def _get_sso_config(self): + def _get_sso_config(self, sso_session=None): scoped_config = self._session.get_scoped_config() - sso_session_config = self._get_sso_session_config(scoped_config) - if sso_session_config: - return sso_session_config - return self._get_legacy_sso_config(scoped_config) - - def _get_sso_session_config(self, scoped_config): - if 'sso_session' not in scoped_config: - return None - - for config_var in self._REQUIRED_SSO_CONFIG_VARS: - if config_var in scoped_config: - raise InvalidSSOConfigError( - 'Inline SSO configuration and sso_session cannot be ' - 'configured on the same profile.' - ) - - session_name = scoped_config['sso_session'] + if sso_session is None: + sso_session = scoped_config.get('sso_session') + if sso_session: + return self._get_sso_session_config(sso_session) + else: + return self._get_legacy_sso_config(scoped_config) + + def _get_sso_session_config(self, session_name): full_config = self._session.full_config if session_name not in full_config.get('sso_sessions', {}): raise InvalidSSOConfigError( @@ -179,7 +179,7 @@ def _get_sso_session_config(self, scoped_config): scopes_var = 'sso_registration_scopes' if scopes_var in session_config: raw_scopes = session_config[scopes_var] - parsed_scopes = self._parse_registration_scopes(raw_scopes) + parsed_scopes = parse_sso_registration_scopes(raw_scopes) sso_config['registration_scopes'] = parsed_scopes if missing: @@ -190,14 +190,6 @@ def _get_sso_session_config(self, scoped_config): return sso_config - def _parse_registration_scopes(self, raw_scopes): - parsed_scopes = [] - for scope in raw_scopes.split(','): - scope = scope.strip() - if scope: - parsed_scopes.append(scope) - return parsed_scopes - def _get_legacy_sso_config(self, scoped_config): sso_config, missing = self._get_required_config_vars(scoped_config) if missing: From bf161f342c2ce48db7b3405955b89d4c057574f3 Mon Sep 17 00:00:00 2001 From: kyleknap Date: Tue, 25 Oct 2022 14:10:30 -0700 Subject: [PATCH 2/4] Update tests for new SSO configuration features This commit both adds the new test cases and refactors the old test cases to leverage pytest fixtures and make setup logic more reusable across test cases. --- tests/functional/sso/__init__.py | 16 +- tests/functional/sso/test_login.py | 28 +- .../unit/customizations/configure/__init__.py | 1 + .../unit/customizations/configure/test_sso.py | 2244 ++++++++++++++--- tests/unit/customizations/sso/test_utils.py | 21 + tests/utils/botocore/__init__.py | 8 + 6 files changed, 1941 insertions(+), 377 deletions(-) diff --git a/tests/functional/sso/__init__.py b/tests/functional/sso/__init__.py index 5dd426dfc826..98036f9173b8 100644 --- a/tests/functional/sso/__init__.py +++ b/tests/functional/sso/__init__.py @@ -67,12 +67,16 @@ def get_legacy_config(self): ) return content - def get_sso_session_config(self, session_name): - content = ( - f'[default]\n' - f'sso_session={session_name}\n' - f'sso_role_name={self.role_name}\n' - f'sso_account_id={self.account}\n' + def get_sso_session_config(self, session_name, include_profile=True): + content = '' + if include_profile: + content += ( + f'[default]\n' + f'sso_session={session_name}\n' + f'sso_role_name={self.role_name}\n' + f'sso_account_id={self.account}\n' + ) + content += ( f'[sso-session {session_name}]\n' f'sso_start_url={self.start_url}\n' f'sso_region={self.sso_region}\n' diff --git a/tests/functional/sso/test_login.py b/tests/functional/sso/test_login.py index dd8af1126f4f..e24c47591716 100644 --- a/tests/functional/sso/test_login.py +++ b/tests/functional/sso/test_login.py @@ -195,6 +195,19 @@ def test_login_sso_session(self): expected_token=self.access_token, ) + def test_login_sso_with_explicit_sso_session_arg(self): + content = self.get_sso_session_config( + 'test-session', include_profile=False) + self.set_config_file_content(content=content) + self.add_oidc_workflow_responses(self.access_token) + self.run_cmd('sso login --sso-session test-session') + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_token( + start_url=self.start_url, + session_name='test-session', + expected_token=self.access_token, + ) + def test_login_sso_session_with_scopes(self): self.registration_scopes = ['sso:foo', 'sso:bar'] content = self.get_sso_session_config('test-session') @@ -211,21 +224,6 @@ def test_login_sso_session_with_scopes(self): self.assertEqual(operation.name, 'RegisterClient') self.assertEqual(params.get('scopes'), self.registration_scopes) - def test_login_sso_session_and_legacy_config_errors(self): - content = self.get_legacy_config() - content += ( - f'sso_session=test\n' - f'[sso-session test]\n' - f'sso_start_url={self.start_url}\n' - f'sso_region={self.sso_region}\n' - ) - self.set_config_file_content(content=content) - _, stderr, _ = self.run_cmd('sso login', expected_rc=253) - self.assertIn( - 'cannot be configured on the same profile', - stderr - ) - def test_login_sso_session_missing_config(self): content = ( f'[default]\n' diff --git a/tests/unit/customizations/configure/__init__.py b/tests/unit/customizations/configure/__init__.py index 6cfcd74ccf8b..60cfce39f4ad 100644 --- a/tests/unit/customizations/configure/__init__.py +++ b/tests/unit/customizations/configure/__init__.py @@ -26,6 +26,7 @@ def __init__(self, all_variables, profile_does_not_exist=False, self.variables = all_variables self.profile_does_not_exist = profile_does_not_exist self.config = {} + self.full_config = {} if config_file_vars is None: config_file_vars = {} self.config_file_vars = config_file_vars diff --git a/tests/unit/customizations/configure/test_sso.py b/tests/unit/customizations/configure/test_sso.py index 0ef24dbd8fe8..b515e763726b 100644 --- a/tests/unit/customizations/configure/test_sso.py +++ b/tests/unit/customizations/configure/test_sso.py @@ -10,29 +10,1575 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +import argparse +import dataclasses +import json +import typing + import mock from datetime import datetime, timedelta + +import prompt_toolkit +import pytest from dateutil.tz import tzlocal from prompt_toolkit import prompt as ptk_prompt from prompt_toolkit.document import Document +from prompt_toolkit.validation import Validator from prompt_toolkit.validation import DummyValidator from prompt_toolkit.validation import ValidationError -from botocore.session import Session from botocore.stub import Stubber -from botocore.exceptions import ProfileNotFound from awscli.testutils import unittest from awscli.customizations.configure.sso import display_account -from awscli.customizations.configure.sso import select_menu from awscli.customizations.configure.sso import PTKPrompt +from awscli.customizations.configure.sso import SSOSessionConfigurationPrompter from awscli.customizations.configure.sso import ConfigureSSOCommand +from awscli.customizations.configure.sso import ConfigureSSOSessionCommand from awscli.customizations.configure.sso import StartUrlValidator -from awscli.customizations.configure.writer import ConfigFileWriter +from awscli.customizations.configure.sso import RequiredInputValidator +from awscli.customizations.configure.sso import ScopesValidator +from awscli.customizations.sso.utils import parse_sso_registration_scopes from awscli.customizations.sso.utils import do_sso_login, PrintOnlyHandler from awscli.formatter import CLI_OUTPUT_FORMATS +from tests import StubbedSession + + +@pytest.fixture +def aws_config(tmp_path): + return tmp_path / "config" + + +@pytest.fixture +def env(aws_config): + env_vars = { + "AWS_DEFAULT_REGION": "us-west-2", + "AWS_ACCESS_KEY_ID": "access_key", + "AWS_SECRET_ACCESS_KEY": "secret_key", + "AWS_CONFIG_FILE": aws_config, + "AWS_SHARED_CREDENTIALS_FILE": "", + } + with mock.patch("os.environ", env_vars): + yield env_vars + + +@pytest.fixture +def access_token(): + return "access.token.string" + + +@pytest.fixture +def account_id(): + return "0123456789" + + +@pytest.fixture +def role_name(): + return "roleA" + + +@pytest.fixture +def sso_session_name(): + return "dev" + + +@pytest.fixture +def scopes(): + return "scope-1, scope-2" + + +@pytest.fixture +def default_sso_scope(): + return "sso:account:access" + + +@pytest.fixture +def existing_profile_name(): + return "existing-profile" + + +@pytest.fixture +def existing_sso_session(): + return "existing-sso-session" + + +@pytest.fixture +def existing_start_url(): + return "https://existing-start-url" + + +@pytest.fixture +def existing_sso_region(): + return "existing-sso-region" + + +@pytest.fixture +def existing_scopes(): + return "existing-scope-1, existing-scope-2" + + +@pytest.fixture +def existing_region(): + return "existing-region" + + +@pytest.fixture +def existing_output(): + return "existing-output" + + +@pytest.fixture +def botocore_session(env): + return StubbedSession() + + +@pytest.fixture +def all_sso_oidc_regions(botocore_session): + return botocore_session.get_available_regions("sso-oidc") + + +@pytest.fixture +def sso_stubber_factory(env, botocore_session): + def create_sso_stubber(session=None): + if session is None: + session = botocore_session + sso_client = session.create_client("sso") + stubber = Stubber(sso_client) + stubber.activate() + return stubber + + return create_sso_stubber + + +@pytest.fixture +def sso_stubber(sso_stubber_factory): + return sso_stubber_factory() + + +@pytest.fixture +def stub_sso_list_accounts(sso_stubber, access_token): + def _do_stub_list_accounts(accounts, override_sso_stubber=None): + stubber = sso_stubber + if override_sso_stubber is not None: + stubber = override_sso_stubber + stubber.add_response( + "list_accounts", + service_response={ + "accountList": accounts, + }, + expected_params={"accessToken": access_token}, + ) + + return _do_stub_list_accounts + + +@pytest.fixture +def stub_sso_list_roles(sso_stubber, access_token): + def _do_stub_list_accounts( + role_names, expected_account_id, override_sso_stubber=None + ): + stubber = sso_stubber + if override_sso_stubber is not None: + stubber = override_sso_stubber + stubber.add_response( + "list_account_roles", + service_response={ + "roleList": [ + {"roleName": role_name} for role_name in role_names + ], + }, + expected_params={ + "accountId": expected_account_id, + "accessToken": access_token, + }, + ) + + return _do_stub_list_accounts + + +@pytest.fixture +def stub_simple_single_item_sso_responses( + sso_stubber, access_token, stub_sso_list_accounts, stub_sso_list_roles +): + def _do_stub_simple_single_item_sso_responses( + account_id, role_name, override_sso_stubber=None + ): + stub_sso_list_accounts( + accounts=[ + { + "accountId": account_id, + "emailAddress": "account@site.com", + } + ], + override_sso_stubber=override_sso_stubber, + ) + stub_sso_list_roles( + role_names=[role_name], + expected_account_id=account_id, + override_sso_stubber=override_sso_stubber, + ) + + return _do_stub_simple_single_item_sso_responses + + +@pytest.fixture +def stub_sso_authorization_error(sso_stubber): + def _do_stub_authorization_error(override_sso_stubber=None): + stubber = sso_stubber + if override_sso_stubber is not None: + stubber = override_sso_stubber + stubber.add_client_error( + "list_accounts", service_error_code="UnauthorizedException" + ) + + return _do_stub_authorization_error + + +@pytest.fixture() +def ptk_stubber(): + return PTKStubber() + + +@pytest.fixture +def prompter(ptk_stubber): + return PTKPrompt(prompter=ptk_stubber.prompt) + + +@pytest.fixture +def sso_config_prompter_factory(env, botocore_session, prompter): + def create_sso_config_prompter(session=None, prompt=None): + if session is None: + session = botocore_session + if prompt is None: + prompt = prompter + return SSOSessionConfigurationPrompter( + botocore_session=session, prompter=prompt + ) + + return create_sso_config_prompter + + +@pytest.fixture +def sso_config_prompter(sso_config_prompter_factory): + return sso_config_prompter_factory() + + +@pytest.fixture +def selector(ptk_stubber): + return ptk_stubber.select_menu + + +@pytest.fixture +def mock_ptk_app(): + mock_app = mock.Mock(spec=prompt_toolkit.application.DummyApplication()) + with prompt_toolkit.application.current.set_app(mock_app): + yield mock_app + + +@pytest.fixture +def mock_do_sso_login(): + login_mock = mock.Mock(spec=do_sso_login) + login_mock.return_value = { + "accessToken": "access.token.string", + "expiresAt": datetime.now(tzlocal()) + timedelta(hours=24), + } + return login_mock + + +@pytest.fixture +def sso_cmd_factory( + env, botocore_session, prompter, mock_do_sso_login, selector +): + def create_sso_cmd(**override_kwargs): + kwargs = { + "session": botocore_session, + "prompter": prompter, + "sso_login": mock_do_sso_login, + "selector": selector, + } + kwargs.update(**override_kwargs) + return ConfigureSSOCommand(**kwargs) + + return create_sso_cmd + + +@pytest.fixture +def sso_cmd(sso_cmd_factory): + return sso_cmd_factory() + + +@pytest.fixture +def sso_session_cmd_factory(env, botocore_session, prompter): + def create_sso_session_cmd(**override_kwargs): + kwargs = {"session": botocore_session, "prompter": prompter} + kwargs.update(**override_kwargs) + return ConfigureSSOSessionCommand(**kwargs) + + return create_sso_session_cmd + + +@pytest.fixture +def sso_session_cmd(sso_session_cmd_factory): + return sso_session_cmd_factory() + + +@pytest.fixture +def args(): + return [] + + +@pytest.fixture +def parsed_globals(): + return argparse.Namespace() + + +@pytest.fixture +def start_url_prompt(): + return StartUrlPrompt(answer="https://starturl", expected_default=None) + + +@pytest.fixture +def sso_region_prompt(): + return SSORegionPrompt(answer="us-west-2", expected_default=None) + + +@pytest.fixture +def scopes_prompt(scopes, default_sso_scope): + return ScopesPrompt(answer=scopes, expected_default=default_sso_scope) + + +@pytest.fixture +def account_id_select(account_id): + selected_account = { + "accountId": account_id, + "emailAddress": "account@site.com", + } + return SelectMenu( + answer=selected_account, + expected_choices=[ + selected_account, + {"accountId": "1234567890", "emailAddress": "account2@site.com"}, + ], + ) + + +@pytest.fixture +def role_name_select(role_name): + return SelectMenu(answer=role_name, expected_choices=[role_name, "roleB"]) + + +@pytest.fixture +def region_prompt(): + return RegionPrompt(answer="us-west-2", expected_default=None) + + +@pytest.fixture +def output_prompt(): + return OutputPrompt(answer="json", expected_default=None) + + +@pytest.fixture +def profile_prompt(role_name, account_id): + return ProfilePrompt( + answer="dev", expected_default=f"{role_name}-{account_id}" + ) + + +@pytest.fixture +def configure_sso_legacy_inputs( + start_url_prompt, + sso_region_prompt, + account_id_select, + role_name_select, + region_prompt, + output_prompt, + profile_prompt, +): + return UserInputs( + session_prompt=RecommendedSessionPrompt(answer=""), + start_url_prompt=start_url_prompt, + sso_region_prompt=sso_region_prompt, + account_id_select=account_id_select, + role_name_select=role_name_select, + region_prompt=region_prompt, + output_prompt=output_prompt, + profile_prompt=profile_prompt, + ) + + +@pytest.fixture +def configure_sso_legacy_with_existing_defaults_inputs( + configure_sso_legacy_inputs, + existing_start_url, + existing_sso_region, + existing_region, + existing_output, +): + inputs = configure_sso_legacy_inputs + inputs.start_url_prompt.expected_default = existing_start_url + inputs.sso_region_prompt.expected_default = existing_sso_region + inputs.region_prompt.expected_default = existing_region + inputs.output_prompt.expected_default = existing_output + return inputs + + +@pytest.fixture +def configure_sso_using_new_session_inputs( + start_url_prompt, + sso_region_prompt, + scopes_prompt, + account_id_select, + role_name_select, + region_prompt, + output_prompt, + profile_prompt, + sso_session_name, +): + return UserInputs( + session_prompt=RecommendedSessionPrompt(answer=sso_session_name), + start_url_prompt=start_url_prompt, + sso_region_prompt=sso_region_prompt, + scopes_prompt=scopes_prompt, + account_id_select=account_id_select, + role_name_select=role_name_select, + region_prompt=region_prompt, + output_prompt=output_prompt, + profile_prompt=profile_prompt, + ) + + +@pytest.fixture() +def configure_sso_using_existing_session_inputs( + account_id_select, + role_name_select, + region_prompt, + output_prompt, + profile_prompt, + existing_sso_session, +): + return UserInputs( + session_prompt=RecommendedSessionPrompt(answer=existing_sso_session), + account_id_select=account_id_select, + role_name_select=role_name_select, + region_prompt=region_prompt, + output_prompt=output_prompt, + profile_prompt=profile_prompt, + ) + + +@pytest.fixture +def configure_sso_with_existing_defaults_inputs( + configure_sso_using_existing_session_inputs, + existing_sso_session, + existing_region, + existing_output, + sso_session_name, +): + inputs = configure_sso_using_existing_session_inputs + inputs.session_prompt = SessionWithDefaultPrompt( + answer=sso_session_name, expected_default=existing_sso_session + ) + inputs.region_prompt.expected_default = existing_region + inputs.output_prompt.expected_default = existing_output + return inputs + + +@pytest.fixture +def configure_sso_using_new_session_from_legacy_profile_inputs( + configure_sso_using_new_session_inputs, + sso_session_name, + existing_start_url, + existing_sso_region, + existing_region, + existing_output, +): + inputs = configure_sso_using_new_session_inputs + inputs.clear_answers() + inputs.session_prompt.answer = sso_session_name + inputs.start_url_prompt.expected_default = existing_start_url + inputs.sso_region_prompt.expected_default = existing_sso_region + inputs.region_prompt.expected_default = existing_region + inputs.output_prompt.expected_default = existing_output + return inputs + + +@pytest.fixture() +def configure_sso_session_inputs( + sso_session_name, start_url_prompt, sso_region_prompt, scopes_prompt +): + return UserInputs( + session_prompt=RequiredSessionPrompt(answer=sso_session_name), + start_url_prompt=start_url_prompt, + sso_region_prompt=sso_region_prompt, + scopes_prompt=scopes_prompt, + ) + + +@pytest.fixture +def configure_sso_session_with_existing_defaults_inputs( + configure_sso_session_inputs, + existing_start_url, + existing_sso_region, + existing_scopes, +): + inputs = configure_sso_session_inputs + inputs.start_url_prompt.expected_default = existing_start_url + inputs.sso_region_prompt.expected_default = existing_sso_region + inputs.scopes_prompt.expected_default = existing_scopes + return inputs + + +@pytest.fixture +def aws_config_lines_for_existing_legacy_profile( + existing_profile_name, + existing_start_url, + existing_sso_region, + existing_region, + existing_output, + account_id, + role_name, +): + return [ + f"[profile {existing_profile_name}]", + f"sso_start_url = {existing_start_url}", + f"sso_region = {existing_sso_region}", + f"sso_account_id = {account_id}", + f"sso_role_name = {role_name}", + f"region = {existing_region}", + f"output = {existing_output}", + ] + + +@pytest.fixture +def aws_config_lines_for_existing_sso_session( + existing_sso_session, + existing_start_url, + existing_sso_region, + existing_scopes, +): + return [ + f"[sso-session {existing_sso_session}]", + f"sso_start_url = {existing_start_url}", + f"sso_region = {existing_sso_region}", + f"sso_registration_scopes = {existing_scopes}", + ] + + +@pytest.fixture +def aws_config_lines_for_existing_profile_and_session( + existing_profile_name, + existing_sso_session, + existing_region, + existing_output, + account_id, + role_name, + aws_config_lines_for_existing_sso_session, +): + return [ + f"[profile {existing_profile_name}]", + f"sso_session = {existing_sso_session}", + f"sso_account_id = {account_id}", + f"sso_role_name = {role_name}", + f"region = {existing_region}", + f"output = {existing_output}", + ] + aws_config_lines_for_existing_sso_session + + +@dataclasses.dataclass +class UserInput: + answer: typing.Any + + +@dataclasses.dataclass +class Prompt(UserInput): + expected_validator_cls: typing.Optional[Validator] = None + expected_completions: typing.Optional[typing.List[str]] = None + _expected_message: typing.Optional[str] = dataclasses.field( + init=False, repr=False, default=None + ) + + @property + def expected_message(self): + return self._expected_message + + @expected_message.setter + def expected_message(self, value): + self._expected_message = value + + +@dataclasses.dataclass +class PromptWithDefault(Prompt): + expected_default: typing.Any = None + msg_format: str = dataclasses.field(init=False) + + @property + def expected_message(self): + if self._expected_message is None: + self._expected_message = self.msg_format.format( + default=self.expected_default + ) + return self._expected_message + + +@dataclasses.dataclass +class StartUrlPrompt(PromptWithDefault): + msg_format: str = dataclasses.field( + init=False, default="SSO start URL [{default}]: " + ) + expected_validator_cls: typing.Optional[Validator] = StartUrlValidator + + +@dataclasses.dataclass +class SSORegionPrompt(PromptWithDefault): + msg_format: str = dataclasses.field( + init=False, default="SSO region [{default}]: " + ) + expected_validator_cls: typing.Optional[Validator] = RequiredInputValidator + + +@dataclasses.dataclass +class ScopesPrompt(PromptWithDefault): + msg_format: str = dataclasses.field( + init=False, default="SSO registration scopes [{default}]: " + ) + expected_validator_cls: typing.Optional[Validator] = ScopesValidator + + +@dataclasses.dataclass +class RequiredSessionPrompt(Prompt): + expected_validator_cls: typing.Optional[Validator] = RequiredInputValidator + + def __post_init__(self): + super().__init__( + answer=self.answer, + expected_validator_cls=self.expected_validator_cls, + ) + self.expected_message = "SSO session name: " + + +@dataclasses.dataclass +class RecommendedSessionPrompt(Prompt): + def __post_init__(self): + super().__init__(answer=self.answer) + self.expected_message = "SSO session name (Recommended): " + + +@dataclasses.dataclass +class SessionWithDefaultPrompt(PromptWithDefault): + msg_format: str = dataclasses.field( + init=False, default="SSO session name [{default}]: " + ) + + +@dataclasses.dataclass +class RegionPrompt(PromptWithDefault): + msg_format: str = dataclasses.field( + init=False, default="CLI default client Region [{default}]: " + ) + + +@dataclasses.dataclass +class OutputPrompt(PromptWithDefault): + msg_format: str = dataclasses.field( + init=False, default="CLI default output format [{default}]: " + ) + + +@dataclasses.dataclass +class ProfilePrompt(PromptWithDefault): + msg_format: str = dataclasses.field( + init=False, default="CLI profile name [{default}]: " + ) + expected_validator_cls: typing.Optional[Validator] = RequiredInputValidator + + +@dataclasses.dataclass +class SelectMenu(UserInput): + expected_choices: typing.Optional[typing.List[typing.Any]] = None + + +@dataclasses.dataclass +class UserInputs: + session_prompt: typing.Optional[Prompt] = None + start_url_prompt: typing.Optional[StartUrlPrompt] = None + sso_region_prompt: typing.Optional[SSORegionPrompt] = None + scopes_prompt: typing.Optional[ScopesPrompt] = None + account_id_select: typing.Optional[SelectMenu] = None + role_name_select: typing.Optional[SelectMenu] = None + region_prompt: typing.Optional[RegionPrompt] = None + output_prompt: typing.Optional[OutputPrompt] = None + profile_prompt: typing.Optional[ProfilePrompt] = None + + def get_expected_inputs(self): + expected_inputs = [] + for possible_input_field in dataclasses.fields(self): + possible_input = getattr(self, possible_input_field.name) + if possible_input is not None: + expected_inputs.append(possible_input) + return expected_inputs + + def clear_answers(self): + for user_input in self.get_expected_inputs(): + user_input.answer = "" + + def skip_account_and_role_selection(self): + self.account_id_select = None + self.role_name_select = None + + def skip_profile_prompt(self): + self.profile_prompt = None + + +class PTKStubber: + _ALLOWED_PROMPT_KWARGS = { + "validator", + "validate_while_typing", + "completer", + "complete_while_typing", + "bottom_toolbar", + "refresh_interval", + "style", + } + _ALLOWED_SELECT_MENU_KWARGS = { + "display_format", + "max_height", + } + + def __init__(self, user_inputs=None): + if user_inputs is None: + user_inputs = UserInputs() + self.user_inputs = user_inputs + self._expected_inputs = None + + def prompt(self, message, **kwargs): + self._initialize_expected_inputs_if_needed() + self._validate_kwargs(kwargs, self._ALLOWED_PROMPT_KWARGS) + if not self._expected_inputs: + raise AssertionError( + f'Received prompt with no stubbed answer: "{message}"' + ) + prompt = self._expected_inputs.pop(0) + assert isinstance( + prompt, Prompt + ), f'Did not receive user input of type Prompt for: "{message}"' + if prompt.expected_message is not None: + assert message == prompt.expected_message, ( + f"Prompt does not match expected " + f'prompt for answer: "{prompt}"' + ) + if prompt.expected_validator_cls: + assert isinstance( + kwargs.get("validator"), prompt.expected_validator_cls + ) + if prompt.expected_completions is not None: + provided_completer = kwargs.get("completer") + assert provided_completer is not None, ( + f"Expected completions but no completer was provided for " + f"prompt: {prompt}" + ) + assert provided_completer.words == prompt.expected_completions + return prompt.answer + + def select_menu(self, items, **kwargs): + self._initialize_expected_inputs_if_needed() + self._validate_kwargs(kwargs, self._ALLOWED_SELECT_MENU_KWARGS) + if not self._expected_inputs: + raise AssertionError( + f'Received select_menu with no stubbed answer: "{items}"' + ) + select_menu = self._expected_inputs.pop(0) + assert isinstance( + select_menu, SelectMenu + ), f'Did not receive user input of type SelectMenu for: "{items}"' + if select_menu.expected_choices is not None: + assert items == select_menu.expected_choices, ( + f"Choices does not match expected select_menu choices " + f'for answer: "{select_menu.answer}"' + ) + return select_menu.answer + + def _initialize_expected_inputs_if_needed(self): + if self._expected_inputs is None: + self._expected_inputs = self.user_inputs.get_expected_inputs() + + def _validate_kwargs(self, provided_kwargs, allowed_kwargs): + assert set(provided_kwargs).issubset( + allowed_kwargs + ), "Arguments provided does not matched allowed keyword arguments" + + +def write_aws_config(aws_config, lines): + with open(aws_config, "w") as f: + content = "\n".join(lines) + f.write(content + "\n") + + +def assert_aws_config(aws_config, expected_lines): + with open(aws_config, "r") as f: + assert f.read().splitlines() == expected_lines + + +class TestConfigureSSOCommand: + def assert_do_sso_login_call( + self, + mock_do_sso_login, + botocore_session, + expected_sso_region, + expected_start_url, + expected_session_name=None, + expected_scopes=None, + expected_auth_handler_cls=None, + expected_force_refresh=None, + ): + expected_kwargs = { + "sso_region": expected_sso_region, + "start_url": expected_start_url, + "on_pending_authorization": None, + "token_cache": None, + } + if expected_session_name is not None: + expected_kwargs["session_name"] = expected_session_name + if expected_scopes is not None: + expected_kwargs["registration_scopes"] = expected_scopes + if expected_auth_handler_cls: + expected_kwargs["on_pending_authorization"] = mock.ANY + if expected_force_refresh is not None: + expected_kwargs["force_refresh"] = expected_force_refresh + + mock_do_sso_login.assert_called_with( + botocore_session, **expected_kwargs + ) + + if expected_auth_handler_cls: + _, _, login_kwargs = mock_do_sso_login.mock_calls[0] + auth_handler = login_kwargs["on_pending_authorization"] + assert isinstance(auth_handler, expected_auth_handler_cls) + + def test_legacy_configure_sso_flow( + self, + sso_cmd, + ptk_stubber, + aws_config, + stub_sso_list_roles, + stub_sso_list_accounts, + mock_do_sso_login, + botocore_session, + args, + parsed_globals, + configure_sso_legacy_inputs, + capsys, + ): + inputs = configure_sso_legacy_inputs + selected_account_id = inputs.account_id_select.answer["accountId"] + ptk_stubber.user_inputs = inputs + stub_sso_list_accounts(inputs.account_id_select.expected_choices) + stub_sso_list_roles( + inputs.role_name_select.expected_choices, + expected_account_id=selected_account_id, + ) + + sso_cmd(args, parsed_globals) + self.assert_do_sso_login_call( + mock_do_sso_login, + botocore_session, + expected_sso_region=inputs.sso_region_prompt.answer, + expected_start_url=inputs.start_url_prompt.answer, + ) + assert_aws_config( + aws_config, + expected_lines=[ + f"[profile {inputs.profile_prompt.answer}]", + f"sso_start_url = {inputs.start_url_prompt.answer}", + f"sso_region = {inputs.sso_region_prompt.answer}", + f"sso_account_id = {selected_account_id}", + f"sso_role_name = {inputs.role_name_select.answer}", + f"region = {inputs.region_prompt.answer}", + f"output = {inputs.output_prompt.answer}", + ], + ) + stdout = capsys.readouterr().out + assert "WARNING: Configuring using legacy format" in stdout + assert f"aws s3 ls --profile {inputs.profile_prompt.answer}" in stdout + + def test_single_account_single_role_flow_no_browser( + self, + sso_cmd, + ptk_stubber, + aws_config, + stub_simple_single_item_sso_responses, + mock_do_sso_login, + botocore_session, + parsed_globals, + configure_sso_legacy_inputs, + account_id, + role_name, + ): + inputs = configure_sso_legacy_inputs + inputs.skip_account_and_role_selection() + ptk_stubber.user_inputs = inputs + stub_simple_single_item_sso_responses(account_id, role_name) + + sso_cmd(["--no-browser"], parsed_globals) + self.assert_do_sso_login_call( + mock_do_sso_login, + botocore_session, + expected_sso_region=inputs.sso_region_prompt.answer, + expected_start_url=inputs.start_url_prompt.answer, + expected_auth_handler_cls=PrintOnlyHandler, + ) + assert_aws_config( + aws_config, + expected_lines=[ + f"[profile {inputs.profile_prompt.answer}]", + f"sso_start_url = {inputs.start_url_prompt.answer}", + f"sso_region = {inputs.sso_region_prompt.answer}", + f"sso_account_id = {account_id}", + f"sso_role_name = {role_name}", + f"region = {inputs.region_prompt.answer}", + f"output = {inputs.output_prompt.answer}", + ], + ) + + def test_single_account_single_role_flow( + self, + sso_cmd, + ptk_stubber, + aws_config, + stub_simple_single_item_sso_responses, + mock_do_sso_login, + botocore_session, + args, + parsed_globals, + configure_sso_legacy_inputs, + account_id, + role_name, + ): + inputs = configure_sso_legacy_inputs + inputs.skip_account_and_role_selection() + ptk_stubber.user_inputs = inputs + stub_simple_single_item_sso_responses(account_id, role_name) + + sso_cmd(args, parsed_globals) + self.assert_do_sso_login_call( + mock_do_sso_login, + botocore_session, + expected_sso_region=inputs.sso_region_prompt.answer, + expected_start_url=inputs.start_url_prompt.answer, + ) + assert_aws_config( + aws_config, + expected_lines=[ + f"[profile {inputs.profile_prompt.answer}]", + f"sso_start_url = {inputs.start_url_prompt.answer}", + f"sso_region = {inputs.sso_region_prompt.answer}", + f"sso_account_id = {account_id}", + f"sso_role_name = {role_name}", + f"region = {inputs.region_prompt.answer}", + f"output = {inputs.output_prompt.answer}", + ], + ) + + def test_no_accounts_flow_raises_error( + self, + sso_cmd, + ptk_stubber, + sso_stubber, + stub_sso_list_accounts, + args, + parsed_globals, + configure_sso_legacy_inputs, + ): + ptk_stubber.user_inputs = configure_sso_legacy_inputs + stub_sso_list_accounts([]) + with pytest.raises(RuntimeError): + sso_cmd(args, parsed_globals) + sso_stubber.assert_no_pending_responses() + + def test_no_roles_flow_raises_error( + self, + sso_cmd, + ptk_stubber, + sso_stubber, + stub_sso_list_accounts, + stub_sso_list_roles, + args, + parsed_globals, + configure_sso_legacy_inputs, + ): + only_account = configure_sso_legacy_inputs.account_id_select.answer + configure_sso_legacy_inputs.account_id_select = None + ptk_stubber.user_inputs = configure_sso_legacy_inputs + stub_sso_list_accounts([only_account]) + stub_sso_list_roles([], expected_account_id=only_account["accountId"]) + with pytest.raises(RuntimeError): + sso_cmd(args, parsed_globals) + sso_stubber.assert_no_pending_responses() + + def test_defaults_to_scoped_config( + self, + sso_cmd_factory, + ptk_stubber, + aws_config, + sso_stubber_factory, + stub_simple_single_item_sso_responses, + mock_do_sso_login, + args, + parsed_globals, + configure_sso_legacy_with_existing_defaults_inputs, + aws_config_lines_for_existing_legacy_profile, + account_id, + role_name, + existing_profile_name, + ): + write_aws_config( + aws_config, lines=aws_config_lines_for_existing_legacy_profile + ) + session = StubbedSession(profile=existing_profile_name) + + inputs = configure_sso_legacy_with_existing_defaults_inputs + inputs.skip_account_and_role_selection() + inputs.skip_profile_prompt() + inputs.clear_answers() + ptk_stubber.user_inputs = inputs + + sso_stubber = sso_stubber_factory(session) + stub_simple_single_item_sso_responses( + account_id, role_name, sso_stubber + ) + + sso_cmd = sso_cmd_factory(session=session) + sso_cmd(args, parsed_globals) + self.assert_do_sso_login_call( + mock_do_sso_login, + session, + expected_sso_region=inputs.sso_region_prompt.expected_default, + expected_start_url=inputs.start_url_prompt.expected_default, + ) + assert_aws_config( + aws_config, + expected_lines=aws_config_lines_for_existing_legacy_profile, + ) + + def test_handles_non_existent_profile( + self, + sso_cmd_factory, + ptk_stubber, + aws_config, + botocore_session, + stub_simple_single_item_sso_responses, + mock_do_sso_login, + args, + parsed_globals, + configure_sso_legacy_inputs, + account_id, + role_name, + ): + inputs = configure_sso_legacy_inputs + inputs.skip_account_and_role_selection() + inputs.skip_profile_prompt() + ptk_stubber.user_inputs = inputs + stub_simple_single_item_sso_responses(account_id, role_name) + + new_session = StubbedSession(profile="new-profile") + # We use the default session to create the stubbed clients because + # if we create the stubbed clients with a non-existent profile, we will + # get a ProfileNotFound error. So after the clients' creation we + # assign them to be used in the session using the new profile. + new_session.cached_clients.update(botocore_session.cached_clients) + new_session.client_stubs.update(botocore_session.client_stubs) + + sso_cmd = sso_cmd_factory(session=new_session) + sso_cmd(args, parsed_globals) + self.assert_do_sso_login_call( + mock_do_sso_login, + new_session, + expected_sso_region=inputs.sso_region_prompt.answer, + expected_start_url=inputs.start_url_prompt.answer, + ) + assert_aws_config( + aws_config, + expected_lines=[ + f"[profile new-profile]", + f"sso_start_url = {inputs.start_url_prompt.answer}", + f"sso_region = {inputs.sso_region_prompt.answer}", + f"sso_account_id = {account_id}", + f"sso_role_name = {role_name}", + f"region = {inputs.region_prompt.answer}", + f"output = {inputs.output_prompt.answer}", + ], + ) + + def test_cli_config_is_none_not_written( + self, + sso_cmd, + ptk_stubber, + aws_config, + botocore_session, + stub_simple_single_item_sso_responses, + args, + parsed_globals, + configure_sso_legacy_inputs, + account_id, + role_name, + ): + inputs = configure_sso_legacy_inputs + inputs.skip_account_and_role_selection() + inputs.region_prompt.answer = "" + inputs.output_prompt.answer = "" + ptk_stubber.user_inputs = inputs + stub_simple_single_item_sso_responses(account_id, role_name) + + sso_cmd(args, parsed_globals) + + assert_aws_config( + aws_config, + expected_lines=[ + f"[profile {inputs.profile_prompt.answer}]", + f"sso_start_url = {inputs.start_url_prompt.answer}", + f"sso_region = {inputs.sso_region_prompt.answer}", + f"sso_account_id = {account_id}", + f"sso_role_name = {role_name}", + ], + ) + + def test_prompts_suggest_values_from_profiles( + self, + sso_cmd_factory, + ptk_stubber, + aws_config, + sso_stubber_factory, + stub_simple_single_item_sso_responses, + aws_config_lines_for_existing_legacy_profile, + existing_start_url, + args, + parsed_globals, + configure_sso_legacy_inputs, + account_id, + role_name, + all_sso_oidc_regions, + ): + write_aws_config( + aws_config, lines=aws_config_lines_for_existing_legacy_profile + ) + session = StubbedSession() + + inputs = configure_sso_legacy_inputs + inputs.skip_account_and_role_selection() + inputs.start_url_prompt.expected_completions = [existing_start_url] + inputs.sso_region_prompt.expected_completions = all_sso_oidc_regions + inputs.output_prompt.expected_completions = list( + CLI_OUTPUT_FORMATS.keys() + ) + ptk_stubber.user_inputs = inputs + + sso_stubber = sso_stubber_factory(session) + stub_simple_single_item_sso_responses( + account_id, role_name, sso_stubber + ) + sso_cmd = sso_cmd_factory(session=session) + assert sso_cmd(args, parsed_globals) == 0 + + def test_configure_sso_with_new_sso_session( + self, + sso_cmd, + ptk_stubber, + aws_config, + stub_sso_list_roles, + stub_sso_list_accounts, + mock_do_sso_login, + botocore_session, + args, + parsed_globals, + configure_sso_using_new_session_inputs, + capsys, + ): + inputs = configure_sso_using_new_session_inputs + selected_account_id = inputs.account_id_select.answer["accountId"] + ptk_stubber.user_inputs = inputs + + stub_sso_list_accounts(inputs.account_id_select.expected_choices) + stub_sso_list_roles( + inputs.role_name_select.expected_choices, + expected_account_id=selected_account_id, + ) + + sso_cmd(args, parsed_globals) + self.assert_do_sso_login_call( + mock_do_sso_login, + botocore_session, + expected_session_name=inputs.session_prompt.answer, + expected_sso_region=inputs.sso_region_prompt.answer, + expected_start_url=inputs.start_url_prompt.answer, + expected_scopes=parse_sso_registration_scopes( + inputs.scopes_prompt.answer + ), + expected_force_refresh=True, + ) + assert_aws_config( + aws_config, + expected_lines=[ + f"[profile {inputs.profile_prompt.answer}]", + f"sso_session = {inputs.session_prompt.answer}", + f"sso_account_id = {selected_account_id}", + f"sso_role_name = {inputs.role_name_select.answer}", + f"region = {inputs.region_prompt.answer}", + f"output = {inputs.output_prompt.answer}", + f"[sso-session {inputs.session_prompt.answer}]", + f"sso_start_url = {inputs.start_url_prompt.answer}", + f"sso_region = {inputs.sso_region_prompt.answer}", + f"sso_registration_scopes = {inputs.scopes_prompt.answer}", + ], + ) + stdout = capsys.readouterr().out + assert "WARNING: Configuring using legacy format" not in stdout + assert f"aws s3 ls --profile {inputs.profile_prompt.answer}" in stdout + + def test_configure_sso_with_existing_sso_session( + self, + sso_cmd_factory, + ptk_stubber, + aws_config, + sso_stubber_factory, + stub_simple_single_item_sso_responses, + mock_do_sso_login, + args, + parsed_globals, + configure_sso_using_existing_session_inputs, + aws_config_lines_for_existing_sso_session, + account_id, + role_name, + existing_start_url, + existing_sso_region, + existing_scopes, + ): + write_aws_config( + aws_config, lines=aws_config_lines_for_existing_sso_session + ) + session = StubbedSession() + + inputs = configure_sso_using_existing_session_inputs + inputs.skip_account_and_role_selection() + ptk_stubber.user_inputs = inputs + + sso_stubber = sso_stubber_factory(session) + stub_simple_single_item_sso_responses( + account_id, role_name, sso_stubber + ) + sso_cmd = sso_cmd_factory(session=session) + sso_cmd(args, parsed_globals) + + self.assert_do_sso_login_call( + mock_do_sso_login, + session, + expected_session_name=inputs.session_prompt.answer, + expected_sso_region=existing_sso_region, + expected_start_url=existing_start_url, + expected_scopes=parse_sso_registration_scopes(existing_scopes), + ) + assert_aws_config( + aws_config, + expected_lines=aws_config_lines_for_existing_sso_session + + [ + f"[profile {inputs.profile_prompt.answer}]", + f"sso_session = {inputs.session_prompt.answer}", + f"sso_account_id = {account_id}", + f"sso_role_name = {role_name}", + f"region = {inputs.region_prompt.answer}", + f"output = {inputs.output_prompt.answer}", + ], + ) + + def test_configure_sso_reusing_existing_configuration( + self, + sso_cmd_factory, + ptk_stubber, + aws_config, + sso_stubber_factory, + stub_simple_single_item_sso_responses, + mock_do_sso_login, + args, + parsed_globals, + configure_sso_with_existing_defaults_inputs, + aws_config_lines_for_existing_profile_and_session, + account_id, + role_name, + existing_profile_name, + existing_start_url, + existing_sso_region, + existing_scopes, + ): + write_aws_config( + aws_config, lines=aws_config_lines_for_existing_profile_and_session + ) + session = StubbedSession(profile=existing_profile_name) + + inputs = configure_sso_with_existing_defaults_inputs + inputs.skip_account_and_role_selection() + inputs.clear_answers() + ptk_stubber.user_inputs = inputs + + sso_stubber = sso_stubber_factory(session) + stub_simple_single_item_sso_responses( + account_id, role_name, sso_stubber + ) + sso_cmd = sso_cmd_factory(session=session) + sso_cmd(args, parsed_globals) + + self.assert_do_sso_login_call( + mock_do_sso_login, + session, + expected_session_name=inputs.session_prompt.expected_default, + expected_sso_region=existing_sso_region, + expected_start_url=existing_start_url, + expected_scopes=parse_sso_registration_scopes(existing_scopes), + ) + assert_aws_config( + aws_config, + expected_lines=aws_config_lines_for_existing_profile_and_session, + ) + + def test_configure_sso_skips_account_role_config_when_no_access( + self, + sso_cmd, + ptk_stubber, + aws_config, + stub_sso_authorization_error, + mock_do_sso_login, + botocore_session, + args, + parsed_globals, + configure_sso_using_new_session_inputs, + capsys, + ): + inputs = configure_sso_using_new_session_inputs + inputs.skip_account_and_role_selection() + inputs.profile_prompt.expected_default = None + ptk_stubber.user_inputs = inputs + + stub_sso_authorization_error() + + sso_cmd(args, parsed_globals) + assert_aws_config( + aws_config, + expected_lines=[ + f"[profile {inputs.profile_prompt.answer}]", + f"sso_session = {inputs.session_prompt.answer}", + f"region = {inputs.region_prompt.answer}", + f"output = {inputs.output_prompt.answer}", + f"[sso-session {inputs.session_prompt.answer}]", + f"sso_start_url = {inputs.start_url_prompt.answer}", + f"sso_region = {inputs.sso_region_prompt.answer}", + f"sso_registration_scopes = {inputs.scopes_prompt.answer}", + ], + ) + stdout = capsys.readouterr().out + profile_answer = inputs.profile_prompt.answer + assert "Unable to list AWS accounts" in stdout + assert f"configured SSO for profile: {profile_answer}" in stdout + + def test_configure_sso_uses_profile_values_when_making_new_session( + self, + sso_cmd_factory, + ptk_stubber, + aws_config, + sso_stubber_factory, + stub_simple_single_item_sso_responses, + mock_do_sso_login, + args, + parsed_globals, + configure_sso_using_new_session_from_legacy_profile_inputs, + aws_config_lines_for_existing_legacy_profile, + account_id, + role_name, + existing_profile_name, + default_sso_scope, + ): + write_aws_config( + aws_config, lines=aws_config_lines_for_existing_legacy_profile + ) + session = StubbedSession(profile=existing_profile_name) + + inputs = configure_sso_using_new_session_from_legacy_profile_inputs + inputs.skip_account_and_role_selection() + ptk_stubber.user_inputs = inputs + + sso_stubber = sso_stubber_factory(session) + stub_simple_single_item_sso_responses( + account_id, role_name, sso_stubber + ) + sso_cmd = sso_cmd_factory(session=session) + sso_cmd(args, parsed_globals) + + self.assert_do_sso_login_call( + mock_do_sso_login, + session, + expected_session_name=inputs.session_prompt.answer, + expected_sso_region=inputs.sso_region_prompt.expected_default, + expected_start_url=inputs.start_url_prompt.expected_default, + expected_scopes=[default_sso_scope], + expected_force_refresh=True, + ) + assert_aws_config( + aws_config, + expected_lines=aws_config_lines_for_existing_legacy_profile + + [ + f"sso_session = {inputs.session_prompt.answer}", + f"[sso-session {inputs.session_prompt.answer}]", + f"sso_start_url = {inputs.start_url_prompt.expected_default}", + f"sso_region = {inputs.sso_region_prompt.expected_default}", + f"sso_registration_scopes = {default_sso_scope}", + ], + ) + + def test_configure_sso_suggests_values_from_sessions( + self, + sso_cmd_factory, + ptk_stubber, + aws_config, + sso_stubber_factory, + stub_simple_single_item_sso_responses, + aws_config_lines_for_existing_sso_session, + existing_sso_session, + existing_start_url, + args, + parsed_globals, + configure_sso_using_new_session_inputs, + account_id, + role_name, + all_sso_oidc_regions, + ): + write_aws_config( + aws_config, lines=aws_config_lines_for_existing_sso_session + ) + session = StubbedSession() + + inputs = configure_sso_using_new_session_inputs + inputs.skip_account_and_role_selection() + inputs.session_prompt.expected_completions = [existing_sso_session] + inputs.start_url_prompt.expected_completions = [existing_start_url] + inputs.sso_region_prompt.expected_completions = all_sso_oidc_regions + inputs.output_prompt.expected_completions = list( + CLI_OUTPUT_FORMATS.keys() + ) + ptk_stubber.user_inputs = inputs + + sso_stubber = sso_stubber_factory(session) + stub_simple_single_item_sso_responses( + account_id, role_name, sso_stubber + ) + sso_cmd = sso_cmd_factory(session=session) + assert sso_cmd(args, parsed_globals) == 0 + + +class TestConfigureSSOSessionCommand: + def test_new_sso_session( + self, + sso_session_cmd, + ptk_stubber, + aws_config, + configure_sso_session_inputs, + args, + parsed_globals, + capsys, + ): + inputs = configure_sso_session_inputs + ptk_stubber.user_inputs = inputs + sso_session_cmd(args, parsed_globals) + assert_aws_config( + aws_config, + expected_lines=[ + f"[sso-session {inputs.session_prompt.answer}]", + f"sso_start_url = {inputs.start_url_prompt.answer}", + f"sso_region = {inputs.sso_region_prompt.answer}", + f"sso_registration_scopes = {inputs.scopes_prompt.answer}", + ], + ) + expected_login = ( + f"aws sso login --sso-session {inputs.session_prompt.answer}" + ) + assert expected_login in capsys.readouterr().out + + def test_can_used_default_scope_for_new_session( + self, + sso_session_cmd, + ptk_stubber, + aws_config, + configure_sso_session_inputs, + args, + parsed_globals, + default_sso_scope, + ): + inputs = configure_sso_session_inputs + inputs.scopes_prompt.answer = "" + ptk_stubber.user_inputs = inputs + sso_session_cmd(args, parsed_globals) + assert_aws_config( + aws_config, + expected_lines=[ + f"[sso-session {inputs.session_prompt.answer}]", + f"sso_start_url = {inputs.start_url_prompt.answer}", + f"sso_region = {inputs.sso_region_prompt.answer}", + f"sso_registration_scopes = {default_sso_scope}", + ], + ) + + def test_reuse_existing_sso_session_configurations( + self, + sso_session_cmd_factory, + ptk_stubber, + aws_config, + aws_config_lines_for_existing_sso_session, + configure_sso_session_with_existing_defaults_inputs, + args, + parsed_globals, + existing_sso_session, + ): + write_aws_config( + aws_config, lines=aws_config_lines_for_existing_sso_session + ) + inputs = configure_sso_session_with_existing_defaults_inputs + inputs.clear_answers() + inputs.session_prompt.answer = existing_sso_session + ptk_stubber.user_inputs = inputs + + sso_session_cmd = sso_session_cmd_factory(session=StubbedSession()) + sso_session_cmd(args, parsed_globals) + assert_aws_config( + aws_config, expected_lines=aws_config_lines_for_existing_sso_session + ) + + def test_override_existing_sso_session_configurations( + self, + sso_session_cmd_factory, + ptk_stubber, + aws_config, + aws_config_lines_for_existing_sso_session, + configure_sso_session_with_existing_defaults_inputs, + args, + parsed_globals, + existing_sso_session, + ): + write_aws_config( + aws_config, lines=aws_config_lines_for_existing_sso_session + ) + inputs = configure_sso_session_with_existing_defaults_inputs + inputs.session_prompt.answer = existing_sso_session + ptk_stubber.user_inputs = inputs + + sso_session_cmd = sso_session_cmd_factory(session=StubbedSession()) + sso_session_cmd(args, parsed_globals) + assert_aws_config( + aws_config, + expected_lines=[ + f"[sso-session {existing_sso_session}]", + f"sso_start_url = {inputs.start_url_prompt.answer}", + f"sso_region = {inputs.sso_region_prompt.answer}", + f"sso_registration_scopes = {inputs.scopes_prompt.answer}", + ], + ) + class TestPTKPrompt(unittest.TestCase): def setUp(self): @@ -40,403 +1586,389 @@ def setUp(self): self.prompter = PTKPrompt(prompter=self.mock_prompter) def test_returns_input(self): - self.mock_prompter.return_value = 'new_value' - response = self.prompter.get_value('default_value', 'Prompt Text') - self.assertEqual(response, 'new_value') + self.mock_prompter.return_value = "new_value" + response = self.prompter.get_value("default_value", "Prompt Text") + self.assertEqual(response, "new_value") def test_user_hits_enter_returns_current(self): - self.mock_prompter.return_value = '' - response = self.prompter.get_value('default_value', 'Prompt Text') + self.mock_prompter.return_value = "" + response = self.prompter.get_value("default_value", "Prompt Text") # We convert the empty string to the default value - self.assertEqual(response, 'default_value') + self.assertEqual(response, "default_value") def assert_expected_completions(self, completions): # The order of the completion list can vary becuase it comes from the # dict's keys. Asserting that each expected completion is in the list _, kwargs = self.mock_prompter.call_args_list[0] - completer = kwargs['completer'] + completer = kwargs["completer"] self.assertEqual(len(completions), len(completer.words)) for completion in completions: self.assertIn(completion, completer.words) def assert_expected_meta_dict(self, meta_dict): _, kwargs = self.mock_prompter.call_args_list[0] - self.assertEqual(kwargs['completer'].meta_dict, meta_dict) + self.assertEqual(kwargs["completer"].meta_dict, meta_dict) def assert_expected_validator(self, validator): _, kwargs = self.mock_prompter.call_args_list[0] - self.assertEqual(kwargs['validator'], validator) + self.assertEqual(kwargs["validator"], validator) + + def assert_expected_toolbar(self, expected_toolbar): + _, kwargs = self.mock_prompter.call_args_list[0] + self.assertEqual(kwargs["bottom_toolbar"], expected_toolbar) + + def assert_expected_prompt_message(self, expected_message): + args, _ = self.mock_prompter.call_args_list[0] + self.assertEqual(args[0], expected_message) def test_handles_list_completions(self): - completions = ['a', 'b'] - self.prompter.get_value('', '', completions=completions) + completions = ["a", "b"] + self.prompter.get_value("", "", completions=completions) self.assert_expected_completions(completions) def test_handles_dict_completions(self): descriptions = { - 'a': 'the letter a', - 'b': 'the letter b', + "a": "the letter a", + "b": "the letter b", } - expected_completions = ['a', 'b'] - self.prompter.get_value('', '', completions=descriptions) + expected_completions = ["a", "b"] + self.prompter.get_value("", "", completions=descriptions) self.assert_expected_completions(expected_completions) self.assert_expected_meta_dict(descriptions) def test_passes_validator(self): validator = DummyValidator() - self.prompter.get_value('', '', validator=validator) + self.prompter.get_value("", "", validator=validator) self.assert_expected_validator(validator) def test_strips_extra_whitespace(self): - self.mock_prompter.return_value = ' no_whitespace \t ' - response = self.prompter.get_value('default_value', 'Prompt Text') - self.assertEqual(response, 'no_whitespace') + self.mock_prompter.return_value = " no_whitespace \t " + response = self.prompter.get_value("default_value", "Prompt Text") + self.assertEqual(response, "no_whitespace") + def test_can_provide_toolbar(self): + toolbar = "Toolbar content" + self.prompter.get_value("default_value", "Prompt Text", toolbar=toolbar) + self.assert_expected_toolbar(toolbar) -class TestStartUrlValidator(unittest.TestCase): - def setUp(self): - self.document = mock.Mock(spec=Document) - self.validator = StartUrlValidator() - - def _validate_text(self, text): - self.document.text = text - self.validator.validate(self.document) - - def assert_text_not_allowed(self, text): - with self.assertRaises(ValidationError): - self._validate_text(text) - - def test_disallowed_text(self): - not_start_urls = [ - '', - 'd-abc123', - 'foo bar baz', - ] - for text in not_start_urls: - self.assert_text_not_allowed(text) - - def test_allowed_text(self): - valid_start_urls = [ - 'https://d-abc123.awsapps.com/start', - 'https://d-abc123.awsapps.com/start#', - 'https://d-abc123.awsapps.com/start/', - 'https://d-abc123.awsapps.com/start-beta', - 'https://start.url', - ] - for text in valid_start_urls: - self._validate_text(text) - - def test_allows_empty_string_if_default(self): - default = 'https://some.default' - self.validator = StartUrlValidator(default) - self._validate_text('') - - -class TestConfigureSSOCommand(unittest.TestCase): - def setUp(self): - self.global_args = mock.Mock() - self._session = Session() - self.sso_client = self._session.create_client( - 'sso', - region_name='us-west-2', - ) - self.sso_stub = Stubber(self.sso_client) - self.profile = 'a-profile' - self.scoped_config = {} - self.full_config = { - 'profiles': { - self.profile: self.scoped_config - } - } - self.mock_session = mock.Mock(spec=Session) - self.mock_session.get_scoped_config.return_value = self.scoped_config - self.mock_session.emit_first_non_none_response.return_value = None - self.mock_session.full_config = self.full_config - self.mock_session.create_client.return_value = self.sso_client - self.mock_session.profile = self.profile - self.config_path = '/some/path' - self.session_config = { - 'config_file': self.config_path, - } - self.mock_session.get_config_variable = self.session_config.get - self.mock_session.get_available_regions.return_value = ['us-east-1'] - self.token_cache = {} - self.writer = mock.Mock(spec=ConfigFileWriter) - self.prompter = mock.Mock(spec=PTKPrompt) - self.selector = mock.Mock(spec=select_menu) - self.region = 'us-west-2' - self.output = 'json' - self.sso_region = 'us-east-1' - self.start_url = 'https://d-92671207e4.awsapps.com/start' - self.account_id = '0123456789' - self.role_name = 'roleA' - self.expires_at = datetime.now(tzlocal()) + timedelta(hours=24) - self.access_token = { - 'accessToken': 'access.token.string', - 'expiresAt': self.expires_at, - } - self.do_sso_login_mock = mock.Mock(spec=do_sso_login) - self.do_sso_login_mock.return_value = self.access_token - self.configure_sso = ConfigureSSOCommand( - self.mock_session, - prompter=self.prompter, - selector=self.selector, - config_writer=self.writer, - sso_token_cache=self.token_cache, - sso_login=self.do_sso_login_mock, - ) - - def _add_list_accounts_response(self, accounts): - params = { - 'accessToken': self.access_token['accessToken'], - } - response = { - 'accountList': accounts, - } - self.sso_stub.add_response('list_accounts', response, params) + def test_can_provide_prompt_format(self): + self.prompter.get_value( + "default_value", + "Prompt Text", + prompt_fmt="{prompt_text} [default: {current_value}]: ", + ) + self.assert_expected_prompt_message( + "Prompt Text [default: default_value]: " + ) + + +class TestSSOSessionConfigurationPrompter: + def get_toolbar_content(self, toolbar_render): + formatted_text = toolbar_render() + content_lines = [line for _, line in formatted_text] + return "".join(content_lines) - def _add_list_account_roles_response(self, roles): - params = { - 'accountId': self.account_id, - 'accessToken': self.access_token['accessToken'], + def test_prompt_for_session_name(self, sso_config_prompter, ptk_stubber): + ptk_stubber.user_inputs = UserInputs( + session_prompt=RequiredSessionPrompt("dev") + ) + assert sso_config_prompter.prompt_for_sso_session() == "dev" + assert sso_config_prompter.sso_session == "dev" + + def test_prompt_for_session_name_opt_out_of_required( + self, sso_config_prompter, ptk_stubber + ): + ptk_stubber.user_inputs = UserInputs( + session_prompt=RecommendedSessionPrompt("") + ) + answer = sso_config_prompter.prompt_for_sso_session(required=False) + assert answer is None + assert sso_config_prompter.sso_session is None + + def test_manually_set_session_name(self, sso_config_prompter): + sso_config_prompter.sso_session = "override" + assert sso_config_prompter.sso_session == "override" + + def test_setting_session_name_updates_sso_config( + self, + sso_config_prompter_factory, + aws_config, + aws_config_lines_for_existing_sso_session, + existing_sso_session, + existing_sso_region, + existing_start_url, + existing_scopes, + ): + write_aws_config(aws_config, aws_config_lines_for_existing_sso_session) + session = StubbedSession() + sso_config_prompter = sso_config_prompter_factory(session=session) + sso_config_prompter.sso_session = existing_sso_session + assert sso_config_prompter.sso_session_config == { + "sso_region": existing_sso_region, + "sso_start_url": existing_start_url, + "sso_registration_scopes": existing_scopes, } - response = { - 'roleList': roles, + + def test_prompt_for_session_suggests_existing_sessions( + self, + sso_config_prompter_factory, + ptk_stubber, + aws_config, + aws_config_lines_for_existing_sso_session, + existing_sso_session, + ): + write_aws_config(aws_config, aws_config_lines_for_existing_sso_session) + session = StubbedSession() + sso_config_prompter = sso_config_prompter_factory(session=session) + + ptk_stubber.user_inputs = UserInputs( + session_prompt=RequiredSessionPrompt( + "dev", expected_completions=[existing_sso_session] + ), + ) + assert sso_config_prompter.prompt_for_sso_session() == "dev" + + def test_prompt_for_session_name_shows_session_config_in_toolbar( + self, + sso_config_prompter_factory, + ptk_stubber, + mock_ptk_app, + aws_config, + aws_config_lines_for_existing_sso_session, + existing_sso_session, + existing_start_url, + existing_sso_region, + existing_scopes, + ): + write_aws_config(aws_config, aws_config_lines_for_existing_sso_session) + session = StubbedSession() + mock_ptk_prompt = mock.Mock(ptk_prompt) + prompter = PTKPrompt(mock_ptk_prompt) + sso_config_prompter = sso_config_prompter_factory( + session=session, + prompt=prompter, + ) + sso_config_prompter.prompt_for_sso_session() + toolbar_render = mock_ptk_prompt.call_args_list[0][1]["bottom_toolbar"] + mock_ptk_app.current_buffer.document.text = existing_sso_session + mock_ptk_app.output.get_size.return_value.columns = 1 + actual_toolbar_content = self.get_toolbar_content(toolbar_render) + expected_sso_config_in_toolbar = json.dumps( + { + "sso_start_url": existing_start_url, + "sso_region": existing_sso_region, + "sso_registration_scopes": existing_scopes, + }, + indent=2, + ) + assert expected_sso_config_in_toolbar in actual_toolbar_content + + def test_prompt_for_start_url(self, sso_config_prompter, ptk_stubber): + url = "https://start.here" + ptk_stubber.user_inputs = UserInputs( + start_url_prompt=StartUrlPrompt(url) + ) + assert sso_config_prompter.prompt_for_sso_start_url() == url + assert sso_config_prompter.sso_session_config == {"sso_start_url": url} + + def test_prompt_for_start_url_reuse_existing_configuration( + self, sso_config_prompter, ptk_stubber, existing_start_url + ): + sso_config_prompter.sso_session_config[ + "sso_start_url" + ] = existing_start_url + ptk_stubber.user_inputs = UserInputs( + start_url_prompt=StartUrlPrompt( + "", expected_default=existing_start_url + ) + ) + answer = sso_config_prompter.prompt_for_sso_start_url() + assert answer == existing_start_url + assert sso_config_prompter.sso_session_config == { + "sso_start_url": existing_start_url } - self.sso_stub.add_response('list_account_roles', response, params) - - def _add_prompt_responses(self): - self.prompter.get_value.side_effect = [ - self.start_url, - self.sso_region, - self.region, - self.output, - ] - - def _add_simple_single_item_responses(self): - selected_account = { - 'accountId': self.account_id, - 'emailAddress': 'account@site.com', + + def test_prompt_for_start_url_suggests_previously_used_start_urls( + self, + sso_config_prompter_factory, + ptk_stubber, + aws_config, + aws_config_lines_for_existing_sso_session, + existing_start_url, + ): + write_aws_config(aws_config, aws_config_lines_for_existing_sso_session) + session = StubbedSession() + url = "https://start.here" + ptk_stubber.user_inputs = UserInputs( + start_url_prompt=StartUrlPrompt( + answer=url, expected_completions=[existing_start_url] + ) + ) + sso_config_prompter = sso_config_prompter_factory(session=session) + answer = sso_config_prompter.prompt_for_sso_start_url() + assert answer == url + + def test_prompt_for_sso_region(self, sso_config_prompter, ptk_stubber): + sso_region = "us-west-2" + ptk_stubber.user_inputs = UserInputs( + sso_region_prompt=SSORegionPrompt(sso_region) + ) + assert sso_config_prompter.prompt_for_sso_region() == sso_region + assert sso_config_prompter.sso_session_config == { + "sso_region": sso_region } - self._add_list_accounts_response([selected_account]) - self._add_list_account_roles_response([{'roleName': self.role_name}]) - - def assert_config_updates(self, config=None): - if config is None: - config = { - '__section__': 'profile %s' % self.profile, - 'sso_start_url': self.start_url, - 'sso_region': self.sso_region, - 'sso_account_id': self.account_id, - 'sso_role_name': self.role_name, - 'region': self.region, - 'output': self.output, - } - self.writer.update_config.assert_called_with(config, self.config_path) - - def test_basic_configure_sso_flow(self): - self._add_prompt_responses() - selected_account = { - 'accountId': self.account_id, - 'emailAddress': 'account@site.com', + + def test_prompt_for_sso_region_reuse_existing_configuration( + self, sso_config_prompter, ptk_stubber, existing_sso_region + ): + sso_config_prompter.sso_session_config[ + "sso_region" + ] = existing_sso_region + ptk_stubber.user_inputs = UserInputs( + sso_region_prompt=SSORegionPrompt( + "", expected_default=existing_sso_region + ) + ) + answer = sso_config_prompter.prompt_for_sso_region() + assert answer == existing_sso_region + assert sso_config_prompter.sso_session_config == { + "sso_region": existing_sso_region } - self.selector.side_effect = [ - selected_account, - self.role_name, - ] - accounts = [ - selected_account, - {'accountId': '1234567890', 'emailAddress': 'account2@site.com'}, - ] - self._add_list_accounts_response(accounts) - roles = [ - {'roleName': self.role_name}, - {'roleName': 'roleB'}, - ] - self._add_list_account_roles_response(roles) - with self.sso_stub: - self.configure_sso(args=[], parsed_globals=self.global_args) - self.sso_stub.assert_no_pending_responses() - self.assert_config_updates() - - def test_single_account_single_role_flow_no_browser(self): - self._add_prompt_responses() - self._add_simple_single_item_responses() - with self.sso_stub: - self.configure_sso( - args=['--no-browser'], - parsed_globals=self.global_args, + + def test_prompt_for_sso_region_suggests_all_valid_sso_oidc_regions( + self, sso_config_prompter, ptk_stubber, all_sso_oidc_regions + ): + sso_region = "us-west-2" + ptk_stubber.user_inputs = UserInputs( + sso_region_prompt=SSORegionPrompt( + sso_region, expected_completions=all_sso_oidc_regions + ), + ) + assert sso_config_prompter.prompt_for_sso_region() == sso_region + + def test_prompt_for_scopes( + self, sso_config_prompter, ptk_stubber, default_sso_scope + ): + scopes = "scope-1, scope-2" + parsed_scopes = ["scope-1", "scope-2"] + ptk_stubber.user_inputs = UserInputs( + scopes_prompt=ScopesPrompt( + scopes, expected_default=default_sso_scope ) - self.sso_stub.assert_no_pending_responses() - self.assert_config_updates() - _, _, login_kwargs = self.do_sso_login_mock.mock_calls[0] - auth_handler = login_kwargs['on_pending_authorization'] - self.assertIsInstance(auth_handler, PrintOnlyHandler) - # Account / Role should be auto selected if only one is returned - self.assertEqual(self.selector.call_count, 0) - - def test_single_account_single_role_flow(self): - self._add_prompt_responses() - self._add_simple_single_item_responses() - with self.sso_stub: - self.configure_sso(args=[], parsed_globals=self.global_args) - self.sso_stub.assert_no_pending_responses() - self.assert_config_updates() - # Account / Role should be auto selected if only one is returned - self.assertEqual(self.selector.call_count, 0) - - def test_no_accounts_flow_raises_error(self): - self.prompter.get_value.side_effect = [self.start_url, self.sso_region] - self._add_list_accounts_response([]) - with self.assertRaises(RuntimeError): - with self.sso_stub: - self.configure_sso(args=[], parsed_globals=self.global_args) - self.sso_stub.assert_no_pending_responses() - - def test_no_roles_flow_raises_error(self): - self._add_prompt_responses() - selected_account = { - 'accountId': self.account_id, - 'emailAddress': 'account@site.com', + ) + answer = sso_config_prompter.prompt_for_sso_registration_scopes() + assert answer == parsed_scopes + assert sso_config_prompter.sso_session_config == { + "sso_registration_scopes": scopes } - self._add_list_accounts_response([selected_account]) - self._add_list_account_roles_response([]) - with self.assertRaises(RuntimeError): - with self.sso_stub: - self.configure_sso(args=[], parsed_globals=self.global_args) - self.sso_stub.assert_no_pending_responses() - - def assert_default_prompt_args(self, defaults): - calls = self.prompter.get_value.call_args_list - self.assertEqual(len(calls), len(defaults)) - for call, default in zip(calls, defaults): - # The default to the prompt call is the first positional param - self.assertEqual(call[0][0], default) - - def assert_prompt_completions(self, completions): - calls = self.prompter.get_value.call_args_list - self.assertEqual(len(calls), len(completions)) - for call, completions in zip(calls, completions): - _, kwargs = call - self.assertEqual(kwargs['completions'], completions) - - def test_defaults_to_scoped_config(self): - self.scoped_config['sso_start_url'] = 'default-url' - self.scoped_config['sso_region'] = 'default-sso-region' - self.scoped_config['region'] = 'default-region' - self.scoped_config['output'] = 'default-output' - self._add_prompt_responses() - self._add_simple_single_item_responses() - with self.sso_stub: - self.configure_sso(args=[], parsed_globals=self.global_args) - self.sso_stub.assert_no_pending_responses() - self.assert_config_updates() - expected_defaults = [ - 'default-url', - 'default-sso-region', - 'default-region', - 'default-output', - ] - self.assert_default_prompt_args(expected_defaults) - - def test_handles_no_profile(self): - expected_profile = 'profile-a' - self.profile = None - self.mock_session.profile = None - self.configure_sso = ConfigureSSOCommand( - self.mock_session, - prompter=self.prompter, - selector=self.selector, - config_writer=self.writer, - sso_token_cache=self.token_cache, - sso_login=self.do_sso_login_mock, - ) - # If there is no profile, it will be prompted for as the last value - self.prompter.get_value.side_effect = [ - self.start_url, - self.sso_region, - self.region, - self.output, - expected_profile, - ] - self._add_simple_single_item_responses() - with self.sso_stub: - self.configure_sso(args=[], parsed_globals=self.global_args) - self.sso_stub.assert_no_pending_responses() - self.profile = expected_profile - self.assert_config_updates() - - def test_handles_non_existant_profile(self): - not_found_exception = ProfileNotFound(profile=self.profile) - self.mock_session.get_scoped_config.side_effect = not_found_exception - self.configure_sso = ConfigureSSOCommand( - self.mock_session, - prompter=self.prompter, - selector=self.selector, - config_writer=self.writer, - sso_token_cache=self.token_cache, - sso_login=self.do_sso_login_mock, - ) - self._add_prompt_responses() - self._add_simple_single_item_responses() - with self.sso_stub: - self.configure_sso(args=[], parsed_globals=self.global_args) - self.sso_stub.assert_no_pending_responses() - self.assert_config_updates() - - def test_cli_config_is_none_not_written(self): - self.prompter.get_value.side_effect = [ - self.start_url, - self.sso_region, - # The CLI region and output format shouldn't be written - # to the config as they are None - None, - None - ] - self._add_simple_single_item_responses() - with self.sso_stub: - self.configure_sso(args=[], parsed_globals=self.global_args) - self.sso_stub.assert_no_pending_responses() - expected_config = { - '__section__': 'profile %s' % self.profile, - 'sso_start_url': self.start_url, - 'sso_region': self.sso_region, - 'sso_account_id': self.account_id, - 'sso_role_name': self.role_name, + + def test_prompt_for_scopes_reuse_existing_configuration( + self, sso_config_prompter, ptk_stubber, existing_scopes + ): + sso_config_prompter.sso_session_config[ + "sso_registration_scopes" + ] = existing_scopes + ptk_stubber.user_inputs = UserInputs( + scopes_prompt=ScopesPrompt("", expected_default=existing_scopes) + ) + answer = sso_config_prompter.prompt_for_sso_registration_scopes() + assert answer == parse_sso_registration_scopes(existing_scopes) + assert sso_config_prompter.sso_session_config == { + "sso_registration_scopes": existing_scopes } - self.assert_config_updates(config=expected_config) - def test_prompts_suggest_values(self): - self.full_config['profiles']['another_profile'] = { - 'sso_start_url': self.start_url, + def test_prompt_for_scopes_used_defaults_account_scope( + self, sso_config_prompter, ptk_stubber, default_sso_scope + ): + ptk_stubber.user_inputs = UserInputs( + scopes_prompt=ScopesPrompt("", expected_default=default_sso_scope) + ) + answer = sso_config_prompter.prompt_for_sso_registration_scopes() + assert answer == [default_sso_scope] + assert sso_config_prompter.sso_session_config == { + "sso_registration_scopes": default_sso_scope } - self._add_prompt_responses() - self._add_simple_single_item_responses() - with self.sso_stub: - self.configure_sso(args=[], parsed_globals=self.global_args) - self.sso_stub.assert_no_pending_responses() - expected_start_urls = [self.start_url] - expected_sso_regions = ['us-east-1'] - expected_cli_regions = None - expected_cli_outputs = list(CLI_OUTPUT_FORMATS.keys()) - expected_completions = [ - expected_start_urls, - expected_sso_regions, - expected_cli_regions, - expected_cli_outputs, - ] - self.assert_prompt_completions(expected_completions) + + def test_prompt_for_scopes_suggest_known_and_previously_used_scopes( + self, + sso_config_prompter_factory, + ptk_stubber, + aws_config, + aws_config_lines_for_existing_sso_session, + default_sso_scope, + existing_scopes, + ): + write_aws_config(aws_config, aws_config_lines_for_existing_sso_session) + session = StubbedSession() + ptk_stubber.user_inputs = UserInputs( + scopes_prompt=ScopesPrompt( + "", + expected_default=default_sso_scope, + expected_completions=[default_sso_scope] + + parse_sso_registration_scopes(existing_scopes), + ) + ) + sso_config_prompter = sso_config_prompter_factory(session=session) + answer = sso_config_prompter.prompt_for_sso_registration_scopes() + assert answer == [default_sso_scope] + + +def passes_validator(validator, text): + document = mock.Mock(spec=Document) + document.text = text + try: + validator.validate(document) + except ValidationError: + return False + return True + + +@pytest.mark.parametrize( + "validator_cls,input_value,default,is_valid", + [ + # StartUrlValidator cases + (StartUrlValidator, "https://d-abc123.awsapps.com/start", None, True), + (StartUrlValidator, "https://d-abc123.awsapps.com/start#", None, True), + (StartUrlValidator, "https://d-abc123.awsapps.com/start/", None, True), + ( + StartUrlValidator, + "https://d-abc123.awsapps.com/start-beta", + None, + True, + ), + (StartUrlValidator, "https://start.url", None, True), + (StartUrlValidator, "", "https://some.default", True), + (StartUrlValidator, "", None, False), + (StartUrlValidator, "d-abc123", None, False), + (StartUrlValidator, "foo bar baz", None, False), + # RequiredInputValidator cases + (RequiredInputValidator, "input-value", "default-value", True), + (RequiredInputValidator, "input-value", None, True), + (RequiredInputValidator, "", "default-value", True), + (RequiredInputValidator, "", None, False), + # ScopesValidator cases + (ScopesValidator, "sso:account:access", "sso:account:access", True), + (ScopesValidator, "", "sso:account:access", True), + (ScopesValidator, "value-1, value-2", None, True), + (ScopesValidator, " value-1, value-2 ", None, True), + (ScopesValidator, "value-1 value-2", None, False), + (ScopesValidator, "value-1, value-2 value3", None, False), + ], +) +def test_validators(validator_cls, input_value, default, is_valid): + validator = validator_cls(default) + assert passes_validator(validator, input_value) == is_valid class TestDisplayAccount(unittest.TestCase): def setUp(self): - self.account_id = '1234' - self.email_address = 'test@test.com' - self.account_name = 'FooBar' + self.account_id = "1234" + self.email_address = "test@test.com" + self.account_name = "FooBar" self.account = { - 'accountId': self.account_id, - 'emailAddress': self.email_address, - 'accountName': self.account_name, + "accountId": self.account_id, + "emailAddress": self.email_address, + "accountName": self.account_name, } def test_display_account_all_fields(self): @@ -446,22 +1978,22 @@ def test_display_account_all_fields(self): self.assertIn(self.account_id, account_str) def test_display_account_missing_email(self): - del self.account['emailAddress'] + del self.account["emailAddress"] account_str = display_account(self.account) self.assertIn(self.account_name, account_str) self.assertNotIn(self.email_address, account_str) self.assertIn(self.account_id, account_str) def test_display_account_missing_name(self): - del self.account['accountName'] + del self.account["accountName"] account_str = display_account(self.account) self.assertNotIn(self.account_name, account_str) self.assertIn(self.email_address, account_str) self.assertIn(self.account_id, account_str) def test_display_account_missing_name_and_email(self): - del self.account['accountName'] - del self.account['emailAddress'] + del self.account["accountName"] + del self.account["emailAddress"] account_str = display_account(self.account) self.assertNotIn(self.account_name, account_str) self.assertNotIn(self.email_address, account_str) diff --git a/tests/unit/customizations/sso/test_utils.py b/tests/unit/customizations/sso/test_utils.py index 0bc85a28f40d..47f5762b3d8c 100644 --- a/tests/unit/customizations/sso/test_utils.py +++ b/tests/unit/customizations/sso/test_utils.py @@ -12,6 +12,9 @@ # language governing permissions and limitations under the License. import os import webbrowser + +import pytest + from awscli.testutils import mock from awscli.testutils import unittest @@ -19,12 +22,30 @@ from botocore.exceptions import ClientError from awscli.compat import StringIO +from awscli.customizations.sso.utils import parse_sso_registration_scopes from awscli.customizations.sso.utils import do_sso_login from awscli.customizations.sso.utils import OpenBrowserHandler from awscli.customizations.sso.utils import PrintOnlyHandler from awscli.customizations.sso.utils import open_browser_with_original_ld_path +@pytest.mark.parametrize( + 'raw_scopes, parsed_scopes', + [ + ('scope', ['scope']), + (' scope ', ['scope']), + ('', []), + ('scope, ', ['scope']), + ('scope-1,scope-2', ['scope-1', 'scope-2']), + ('scope-1, scope-2', ['scope-1', 'scope-2']), + (' scope-1, scope-2 ', ['scope-1', 'scope-2']), + ('scope-1,scope-2,scope-3', ['scope-1', 'scope-2', 'scope-3']) + ] +) +def test_parse_registration_scopes(raw_scopes, parsed_scopes): + assert parse_sso_registration_scopes(raw_scopes) == parsed_scopes + + class TestDoSSOLogin(unittest.TestCase): def setUp(self): self.region = 'us-west-2' diff --git a/tests/utils/botocore/__init__.py b/tests/utils/botocore/__init__.py index 88cdb0635188..964ec88232d7 100644 --- a/tests/utils/botocore/__init__.py +++ b/tests/utils/botocore/__init__.py @@ -506,6 +506,14 @@ def __init__(self, *args, **kwargs): self._cached_clients = {} self._client_stubs = {} + @property + def cached_clients(self): + return self._cached_clients + + @property + def client_stubs(self): + return self._client_stubs + def create_client(self, service_name, *args, **kwargs): if service_name not in self._cached_clients: client = self._create_stubbed_client(service_name, *args, **kwargs) From cd346698ac3a73cf1b4741e5ccf360112211616c Mon Sep 17 00:00:00 2001 From: kyleknap Date: Tue, 25 Oct 2022 18:46:21 -0700 Subject: [PATCH 3/4] Add changelogs for new SSO session features --- .changes/next-release/enhancement-ssologin-96466.json | 5 +++++ .changes/next-release/feature-configuresso-52515.json | 5 +++++ .changes/next-release/feature-configuressosession-45599.json | 5 +++++ 3 files changed, 15 insertions(+) create mode 100644 .changes/next-release/enhancement-ssologin-96466.json create mode 100644 .changes/next-release/feature-configuresso-52515.json create mode 100644 .changes/next-release/feature-configuressosession-45599.json diff --git a/.changes/next-release/enhancement-ssologin-96466.json b/.changes/next-release/enhancement-ssologin-96466.json new file mode 100644 index 000000000000..18539a0cae21 --- /dev/null +++ b/.changes/next-release/enhancement-ssologin-96466.json @@ -0,0 +1,5 @@ +{ + "type": "enhancement", + "category": "``sso login``", + "description": "Add ``--sso-session`` argument to enable direct SSO login with a ``sso-session``" +} diff --git a/.changes/next-release/feature-configuresso-52515.json b/.changes/next-release/feature-configuresso-52515.json new file mode 100644 index 000000000000..8e92e962b726 --- /dev/null +++ b/.changes/next-release/feature-configuresso-52515.json @@ -0,0 +1,5 @@ +{ + "type": "feature", + "category": "``configure sso``", + "description": "Add support for configuring ``sso-session`` as part of configuring SSO-enabled profile" +} diff --git a/.changes/next-release/feature-configuressosession-45599.json b/.changes/next-release/feature-configuressosession-45599.json new file mode 100644 index 000000000000..5a0035f5451f --- /dev/null +++ b/.changes/next-release/feature-configuressosession-45599.json @@ -0,0 +1,5 @@ +{ + "type": "feature", + "category": "``configure sso-session``", + "description": "Add new ``configure sso-session`` command for creating and updating ``sso-session`` configurations" +} From 5f14a5f422baaf4e5147640e1b74e500aa55dc08 Mon Sep 17 00:00:00 2001 From: kyleknap Date: Thu, 17 Nov 2022 12:23:30 -0800 Subject: [PATCH 4/4] Update based on feedback --- awscli/customizations/configure/sso.py | 5 ++--- awscli/customizations/sso/utils.py | 3 +-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/awscli/customizations/configure/sso.py b/awscli/customizations/configure/sso.py index 19442634c53d..759d2cc832e2 100644 --- a/awscli/customizations/configure/sso.py +++ b/awscli/customizations/configure/sso.py @@ -50,7 +50,7 @@ _CMD_PROMPT_USAGE = ( 'To keep an existing value, hit enter when prompted for the value. When ' 'you are prompted for information, the current value will be displayed in ' - '[brackets]. If the config item has no value, it is displayed as ' + '[brackets]. If the config item has no value, it is displayed as ' '[None] or omitted entirely.\n\n' ) _CONFIG_EXTRA_INFO = ( @@ -98,7 +98,6 @@ def validate(self, document): document, 'Scope values must be separated by commas') def _is_comma_separated_list(self, value): - value.strip() scopes = value.split(',') for scope in scopes: if re.findall(r'\s', scope.strip()): @@ -470,7 +469,7 @@ def _prompt_for_profile(self, sso_account_id=None, sso_role_name=None): text = 'CLI profile name' default_profile = None if sso_account_id and sso_role_name: - default_profile = '{}-{}'.format(sso_role_name, sso_account_id) + default_profile = f'{sso_role_name}-{sso_account_id}' validator = RequiredInputValidator(default_profile) profile_name = self._prompter.get_value( default_profile, text, validator=validator) diff --git a/awscli/customizations/sso/utils.py b/awscli/customizations/sso/utils.py index 775379c04780..ae9a83e9e8b8 100644 --- a/awscli/customizations/sso/utils.py +++ b/awscli/customizations/sso/utils.py @@ -80,8 +80,7 @@ def do_sso_login(session, sso_region, start_url, token_cache=None, def parse_sso_registration_scopes(raw_scopes): parsed_scopes = [] for scope in raw_scopes.split(','): - scope = scope.strip() - if scope: + if scope := scope.strip(): parsed_scopes.append(scope) return parsed_scopes