Skip to content

Commit

Permalink
Add support for MFA when assuming a role
Browse files Browse the repository at this point in the history
Feedback from aws#990, this adds support for MFA when assuming a role.
To enable this, in addition to role_arn and source_profile, you can
specify an mfa_serial option in your config file::

    [profile foo]
    role_arn = ...
    source_profile = development
    mfa_serial = .....

This is the the mfa arn/device id.  If an mfa_serial is
provided then a user will be prompted for the token code when
the AssumeRole call happens.

As mentioned in the original PR, for now when the temporary
credentials expire, an exception will be raised if MFA is
required.  We can look into updating this in the future to support
reprompting the user.  This only affects the case where the
credentials expire within the duration of the AWS CLI process.
Aside from some of the ``aws s3 cp/sync`` commands, the AWS CLI
is generally a short lived process so this won't affect the
common usage scenarios.
  • Loading branch information
jamesls committed Nov 10, 2014
1 parent b854792 commit e8d9791
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 19 deletions.
73 changes: 55 additions & 18 deletions awscli/customizations/assumerole.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
import json
import logging
import getpass

from dateutil.parser import parse
from datetime import datetime
Expand All @@ -19,6 +20,10 @@ class InvalidConfigError(Exception):
pass


class RefreshWithMFAUnsupportedError(Exception):
pass


def register_assume_role_provider(event_handlers):
event_handlers.register('building-command-table.*',
inject_assume_role_provider,
Expand All @@ -31,21 +36,26 @@ def inject_assume_role_provider(session, event_name, **kwargs):
# top level command table. We want all the top level args processed
# before we start injecting things into the session.
return
provider = create_assume_role_provider(session)
provider = create_assume_role_provider(session, AssumeRoleProvider)
try:
session.get_component('credential_provider').insert_before(
'config-file', provider)
# The final order will be:
# * env
# * assume-role
# * shared-credentials-file
# * ...
cred_chain = session.get_component('credential_provider')
cred_chain.insert_before('shared-credentials-file', provider)
except Exception:
# This is ok, it just means that we couldn't create the credential
# provider object.
LOG.debug("Not registering assume-role provider, credential "
"provider from session could not be created.")


def create_assume_role_provider(session):
def create_assume_role_provider(session, provider_cls):
profile_name = session.get_config_variable('profile') or 'default'
load_config = lambda: session.full_config
return AssumeRoleProvider(
return provider_cls(
load_config=load_config,
client_creator=session.create_client,
cache=JSONFileCache(AssumeRoleProvider.CACHE_DIR),
Expand All @@ -70,6 +80,16 @@ def refresh():
return refresh


def create_mfa_serial_refresh():
def _refresher():
# We can explore an option in the future to support
# reprompting for MFA, but for now we just error out
# when the temp creds expire.
raise RefreshWithMFAUnsupportedError(
"Cannot refresh credentials: MFA token required.")
return _refresher


class JSONFileCache(object):
"""JSON file cache.
Expand Down Expand Up @@ -121,9 +141,10 @@ class AssumeRoleProvider(credentials.CredentialProvider):
# Credentials are considered expired (and will be refreshed) once the total
# remaining time left until the credentials expires is less than the
# EXPIRY_WINDOW.
EXPIRY_WINDOW_SECONDS = 60 * 5
EXPIRY_WINDOW_SECONDS = 60 * 15

def __init__(self, load_config, client_creator, cache, profile_name):
def __init__(self, load_config, client_creator, cache, profile_name,
prompter=getpass.getpass):
"""
:type load_config: callable
Expand All @@ -144,13 +165,18 @@ def __init__(self, load_config, client_creator, cache, profile_name):
:type profile_name: str
:param profile_name: The name of the profile.
:type prompter: callable
:param prompter: A callable that returns input provided
by the user (i.e raw_input, getpass.getpass, etc.).
"""
self._load_config = load_config
# client_creator is a callable that creates function.
# It's basically session.create_client
self._client_creator = client_creator
self._profile_name = profile_name
self._cache = cache
self._prompter = prompter
# The _loaded_config attribute will be populated from the
# load_config() function once the configuration is actually
# loaded. The reason we go through all this instead of just
Expand Down Expand Up @@ -222,6 +248,7 @@ def _get_role_config_values(self):
try:
source_profile = profiles[self._profile_name]['source_profile']
role_arn = profiles[self._profile_name]['role_arn']
mfa_serial = profiles[self._profile_name].get('mfa_serial')
except KeyError as e:
raise PartialCredentialsError(provider=self.METHOD,
cred_var=str(e))
Expand All @@ -236,22 +263,28 @@ def _get_role_config_values(self):
'role_arn': role_arn,
'external_id': external_id,
'source_profile': source_profile,
'mfa_serial': mfa_serial,
'source_cred_values': source_cred_values,
}


def _create_creds_from_response(self, response):
config = self._get_role_config_values()
if config.get('mfa_serial') is not None:
# MFA would require getting a new TokenCode which would require
# prompting the user for a new token, so we use a different
# refresh_func.
refresh_func = create_mfa_serial_refresh()
else:
refresh_func = create_refresher_function(
self._create_client_from_config(config),
self._assume_role_base_kwargs(config))
return credentials.RefreshableCredentials(
access_key=response['Credentials']['AccessKeyId'],
secret_key=response['Credentials']['SecretAccessKey'],
token=response['Credentials']['SessionToken'],
method=self.METHOD,
expiry_time=parse(response['Credentials']['Expiration']),
refresh_using=create_refresher_function(
self._create_client_from_config(config),
self._assume_role_base_kwargs(config)),
)
refresh_using=refresh_func)

def _create_client_from_config(self, config):
source_cred_values = config['source_cred_values']
Expand All @@ -262,12 +295,6 @@ def _create_client_from_config(self, config):
)
return client

def _assume_role_base_kwargs(self, config):
assume_role_kwargs = {'RoleArn': config['role_arn']}
if config['external_id'] is not None:
assume_role_kwargs['ExternalId'] = config['external_id']
return assume_role_kwargs

def _retrieve_temp_credentials(self):
LOG.debug("Retrieving credentials via AssumeRole.")
config = self._get_role_config_values()
Expand All @@ -280,3 +307,13 @@ def _retrieve_temp_credentials(self):
response = client.assume_role(**assume_role_kwargs)
creds = self._create_creds_from_response(response)
return creds, response

def _assume_role_base_kwargs(self, config):
assume_role_kwargs = {'RoleArn': config['role_arn']}
if config['external_id'] is not None:
assume_role_kwargs['ExternalId'] = config['external_id']
if config['mfa_serial'] is not None:
token_code = self._prompter("Enter MFA code: ")
assume_role_kwargs['SerialNumber'] = config['mfa_serial']
assume_role_kwargs['TokenCode'] = token_code
return assume_role_kwargs
62 changes: 61 additions & 1 deletion tests/unit/customizations/test_assumerole.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_assume_role_provider_injected(self):
session.get_component.assert_called_with('credential_provider')
credential_provider = session.get_component.return_value
call_args = credential_provider.insert_before.call_args[0]
self.assertEqual(call_args[0], 'config-file')
self.assertEqual(call_args[0], 'shared-credentials-file')
self.assertIsInstance(call_args[1], assumerole.AssumeRoleProvider)

def test_assume_role_provider_not_injected_for_main_command_table(self):
Expand Down Expand Up @@ -194,6 +194,66 @@ def test_external_id_provided(self):
client.assume_role.assert_called_with(
RoleArn='myrole', ExternalId='myid', RoleSessionName=mock.ANY)

def test_assume_role_with_mfa(self):
self.fake_config['profiles']['development']['mfa_serial'] = 'mfa'
response = {
'Credentials': {
'AccessKeyId': 'foo',
'SecretAccessKey': 'bar',
'SessionToken': 'baz',
'Expiration': datetime.now(tzlocal()).isoformat(),
},
}
client_creator = self.create_client_creator(with_response=response)
prompter = mock.Mock(return_value='token-code')
provider = assumerole.AssumeRoleProvider(
self.create_config_loader(), client_creator,
cache={}, profile_name='development', prompter=prompter)

provider.load()

client = client_creator.return_value
# In addition to the normal assume role args, we should also
# inject the serial number from the config as well as the
# token code that comes from prompting the user (the prompter
# object).
client.assume_role.assert_called_with(
RoleArn='myrole', RoleSessionName=mock.ANY, SerialNumber='mfa',
TokenCode='token-code')

def test_assume_role_mfa_cannot_refresh_credentials(self):
# Note: we should look into supporting optional behavior
# in the future that allows for reprompting for credentials.
# But for now, if we get temp creds with MFA then when those
# creds expire, we can't refresh the credentials.
self.fake_config['profiles']['development']['mfa_serial'] = 'mfa'
response = {
'Credentials': {
'AccessKeyId': 'foo',
'SecretAccessKey': 'bar',
'SessionToken': 'baz',
# We're creating an expiry time in the past so as
# soon as we try to access the credentials, the
# refresh behavior will be triggered.
'Expiration': (
datetime.now(tzlocal()) -
timedelta(seconds=100)).isoformat(),
},
}
client_creator = self.create_client_creator(with_response=response)
provider = assumerole.AssumeRoleProvider(
self.create_config_loader(), client_creator,
cache={}, profile_name='development',
prompter=mock.Mock(return_value='token-code'))

creds = provider.load()
with self.assertRaises(assumerole.RefreshWithMFAUnsupportedError):
# access_key is a property that will refresh credentials
# if they're expired. Because we set the expiry time to
# something in the past, this will trigger the refresh
# behavior, with with MFA will currently raise an exception.
creds.access_key

def test_no_config_is_noop(self):
self.fake_config['profiles']['development'] = {
'aws_access_key_id': 'foo',
Expand Down

0 comments on commit e8d9791

Please sign in to comment.