diff --git a/awscli/customizations/assumerole.py b/awscli/customizations/assumerole.py new file mode 100644 index 000000000000..cd704dcff524 --- /dev/null +++ b/awscli/customizations/assumerole.py @@ -0,0 +1,276 @@ +import os +import time +import json +import logging + +from dateutil.parser import parse +from datetime import datetime +from dateutil.tz import tzlocal + +from botocore import credentials +from botocore.compat import total_seconds +from botocore.exceptions import PartialCredentialsError + + +LOG = logging.getLogger(__name__) + + +class InvalidConfigError(Exception): + pass + + +def register_assume_role_provider(event_handlers): + event_handlers.register('building-command-table.*', + inject_assume_role_provider, + unique_id='inject_assume_role_cred_provider') + + +def inject_assume_role_provider(session, event_name, **kwargs): + if event_name.endswith('.main'): + # Register the assume role provider only when we're not building the + # top level command table. We want all the top level args processed + # before we start injecting things into the session. + return + provider = create_assume_role_provider(session) + session.get_component('credential_provider').insert_before( + 'config-file', provider) + + +def create_assume_role_provider(session): + profile_name = session.get_config_variable('profile') or 'default' + load_config = lambda: session.full_config + return AssumeRoleProvider( + load_config=load_config, + client_creator=session.create_client, + cache=JSONFileCache(AssumeRoleProvider.CACHE_DIR), + profile_name=profile_name, + ) + + +def create_refresher_function(client, params): + def refresh(): + role_session_name = 'AWS-CLI-session-%s' % (int(time.time())) + params['RoleSessionName'] = role_session_name + response = client.assume_role(**params) + credentials = response['Credentials'] + # We need to normalize the credential names to + # the values expected by the refresh creds. + return { + 'access_key': credentials['AccessKeyId'], + 'secret_key': credentials['SecretAccessKey'], + 'token': credentials['SessionToken'], + 'expiry_time': credentials['Expiration'], + } + return refresh + + +class JSONFileCache(object): + """JSON file cache. + + This provides a dict like interface that stores JSON serializable + objects. + + The objects are serialized to JSON and stored in a file. These + values can be retrieved at a later time. + + """ + def __init__(self, working_dir): + self._working_dir = working_dir + + def __contains__(self, cache_key): + actual_key = self._convert_cache_key(cache_key) + return os.path.isfile(actual_key) + + def __getitem__(self, cache_key): + """Retrieve value from a cache key.""" + actual_key = self._convert_cache_key(cache_key) + try: + with open(actual_key) as f: + return json.load(f) + except (OSError, ValueError, IOError): + raise KeyError(cache_key) + + def __setitem__(self, cache_key, value): + full_key = self._convert_cache_key(cache_key) + try: + file_content = json.dumps(value) + except (TypeError, ValueError): + raise ValueError("Value cannot be cached, must be " + "JSON serializable: %s" % value) + if not os.path.isdir(self._working_dir): + os.makedirs(self._working_dir) + with open(full_key, 'w') as f: + f.write(file_content) + + def _convert_cache_key(self, cache_key): + full_path = os.path.join(self._working_dir, cache_key + '.json') + return full_path + + +class AssumeRoleProvider(credentials.CredentialProvider): + + METHOD = 'assume-role' + CACHE_DIR = os.path.expanduser(os.path.join('~', '.aws', 'cli', 'cache')) + ROLE_CONFIG_VAR = 'role_arn' + # Credentials are considered expired (and will be refreshed) once the total + # remaining time left until the credentials expires is less than the + # EXPIRY_WINDOW. + EXPIRY_WINDOW_SECONDS = 60 * 5 + + def __init__(self, load_config, client_creator, cache, profile_name): + """ + + :type load_config: callable + :param load_config: A function that accepts no arguments, and + when called, will return the full configuration dictionary + for the session (``session.full_config``). + + :type client_creator: callable + :param client_creator: A factory function that will create + a client when called. Has the same interface as + ``botocore.session.Session.create_client``. + + :type cache: JSONFileCache + :param cache: An object that supports ``__getitem__``, + ``__setitem__``, and ``__contains__``. An example + of this is the ``JSONFileCache`` class. + + :type profile_name: str + :param profile_name: The name of the profile. + + """ + self._load_config = load_config + # client_creator is a callable that creates function. + # It's basically session.create_client + self._client_creator = client_creator + self._profile_name = profile_name + self._cache = cache + # The _loaded_config attribute will be populated from the + # load_config() function once the configuration is actually + # loaded. The reason we go through all this instead of just + # requiring that the loaded_config be passed to us is to that + # we can defer configuration loaded until we actually try + # to load credentials (as opposed to when the object is + # instantiated). + self._loaded_config = {} + + def load(self): + self._loaded_config = self._load_config() + if self._has_assume_role_config_vars(): + return self._load_creds_via_assume_role() + + def _has_assume_role_config_vars(self): + profiles = self._loaded_config.get('profiles', {}) + return self.ROLE_CONFIG_VAR in profiles.get(self._profile_name, {}) + + def _load_creds_via_assume_role(self): + # We can get creds in one of two ways: + # * It can either be cached on disk from an pre-existing session + # * Cache doesn't have the creds (or is expired) so we need to make + # an assume role call to get temporary creds, which we then cache + # for subsequent requests. + creds = self._load_creds_from_cache() + if creds is not None: + LOG.debug("Credentials for role retrieved from cache.") + return creds + else: + # We get the Credential used by botocore as well + # as the original parsed response from the server. + creds, response = self._retrieve_temp_credentials() + cache_key = self._create_cache_key() + self._write_cached_credentials(response, cache_key) + return creds + + def _load_creds_from_cache(self): + cache_key = self._create_cache_key() + try: + from_cache = self._cache[cache_key] + if self._is_expired(from_cache): + # Don't need to delete the cache entry, + # when we refresh via AssumeRole, we'll + # update the cache with the new entry. + LOG.debug("Credentials were found in cache, but they are expired.") + return None + else: + return self._create_creds_from_response(from_cache) + except KeyError: + return None + + def _is_expired(self, credentials): + end_time = parse(credentials['Credentials']['Expiration']) + now = datetime.now(tzlocal()) + seconds = total_seconds(end_time - now) + return seconds < self.EXPIRY_WINDOW_SECONDS + + def _create_cache_key(self): + role_config = self._get_role_config_values() + cache_key = '%s--%s' % (self._profile_name, role_config['role_arn']) + return cache_key.replace('/', '-') + + def _write_cached_credentials(self, creds, cache_key): + self._cache[cache_key] = creds + + def _get_role_config_values(self): + # This returns the role related configuration. + profiles = self._loaded_config.get('profiles', {}) + try: + source_profile = profiles[self._profile_name]['source_profile'] + role_arn = profiles[self._profile_name]['role_arn'] + except KeyError as e: + raise PartialCredentialsError(provider=self.METHOD, + cred_var=str(e)) + external_id = profiles[self._profile_name].get('external_id') + if source_profile not in profiles: + raise InvalidConfigError( + 'The source_profile "%s" referenced in ' + 'the profile "%s" does not exist.' % ( + source_profile, self._profile_name)) + source_cred_values = profiles[source_profile] + return { + 'role_arn': role_arn, + 'external_id': external_id, + 'source_profile': source_profile, + 'source_cred_values': source_cred_values, + } + + + def _create_creds_from_response(self, response): + config = self._get_role_config_values() + return credentials.RefreshableCredentials( + access_key=response['Credentials']['AccessKeyId'], + secret_key=response['Credentials']['SecretAccessKey'], + token=response['Credentials']['SessionToken'], + method=self.METHOD, + expiry_time=parse(response['Credentials']['Expiration']), + refresh_using=create_refresher_function( + self._create_client_from_config(config), + self._assume_role_base_kwargs(config)), + ) + + def _create_client_from_config(self, config): + source_cred_values = config['source_cred_values'] + client = self._client_creator( + 'sts', aws_access_key_id=source_cred_values['aws_access_key_id'], + aws_secret_access_key=source_cred_values['aws_secret_access_key'], + aws_session_token=source_cred_values.get('aws_session_token'), + ) + return client + + def _assume_role_base_kwargs(self, config): + assume_role_kwargs = {'RoleArn': config['role_arn']} + if config['external_id'] is not None: + assume_role_kwargs['ExternalId'] = config['external_id'] + return assume_role_kwargs + + def _retrieve_temp_credentials(self): + LOG.debug("Retrieving credentials via AssumeRole.") + config = self._get_role_config_values() + client = self._create_client_from_config(config) + + assume_role_kwargs = self._assume_role_base_kwargs(config) + role_session_name = 'AWS-CLI-session-%s' % (int(time.time())) + assume_role_kwargs['RoleSessionName'] = role_session_name + + response = client.assume_role(**assume_role_kwargs) + creds = self._create_creds_from_response(response) + return creds, response diff --git a/awscli/handlers.py b/awscli/handlers.py index e9d2ce941f60..f7b36287b71a 100644 --- a/awscli/handlers.py +++ b/awscli/handlers.py @@ -51,6 +51,7 @@ from awscli.customizations.cliinputjson import register_cli_input_json from awscli.customizations.generatecliskeleton import \ register_generate_cli_skeleton +from awscli.customizations.assumerole import register_assume_role_provider def awscli_initialize(event_handlers): @@ -105,3 +106,4 @@ def awscli_initialize(event_handlers): register_cloudsearchdomain(event_handlers) register_s3_endpoint(event_handlers) register_generate_cli_skeleton(event_handlers) + register_assume_role_provider(event_handlers) diff --git a/tests/unit/customizations/test_assumerole.py b/tests/unit/customizations/test_assumerole.py new file mode 100644 index 000000000000..bb1a26510194 --- /dev/null +++ b/tests/unit/customizations/test_assumerole.py @@ -0,0 +1,270 @@ +# Copyright 2014 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import shutil +import tempfile +import os +from datetime import datetime, timedelta + +import mock +from botocore.hooks import HierarchicalEmitter +from botocore.exceptions import PartialCredentialsError +from dateutil.tz import tzlocal + +from awscli.testutils import unittest +from awscli.customizations import assumerole + + +class TestAssumeRolePlugin(unittest.TestCase): + def test_assume_role_provider_injected(self): + session = mock.Mock() + assumerole.inject_assume_role_provider( + session, event_name='building-command-table.foo') + + session.get_component.assert_called_with('credential_provider') + credential_provider = session.get_component.return_value + call_args = credential_provider.insert_before.call_args[0] + self.assertEqual(call_args[0], 'config-file') + self.assertIsInstance(call_args[1], assumerole.AssumeRoleProvider) + + def test_assume_role_provider_not_injected_for_main_command_table(self): + session = mock.Mock() + # When the main/top-level command table is created, it's emitted with + # an event name of building-command-table.main. We want to verify + # that the assumerole provider is not hooked up when that happens. + assumerole.inject_assume_role_provider( + session, event_name='building-command-table.main') + self.assertFalse(session.get_component.called) + + def test_assume_role_provider_registration(self): + event_handlers = HierarchicalEmitter() + assumerole.register_assume_role_provider(event_handlers) + session = mock.Mock() + event_handlers.emit('building-command-table.foo', session=session) + # Just verifying that anything on the session was called ensures + # that our handler was called, as it's the only thing that should + # be registered. + session.get_component.assert_called_with('credential_provider') + + +class TestAssumeRoleCredentialProvider(unittest.TestCase): + + maxDiff = None + + def setUp(self): + self.fake_config = { + 'profiles': { + 'development': { + 'role_arn': 'myrole', + 'source_profile': 'longterm', + }, + 'longterm': { + 'aws_access_key_id': 'akid', + 'aws_secret_access_key': 'skid', + } + } + } + + def create_config_loader(self, with_config=None): + if with_config is None: + with_config = self.fake_config + load_config = mock.Mock() + load_config.return_value = with_config + return load_config + + def create_client_creator(self, with_response): + # Create a mock sts client that returns a specific response + # for assume_role. + client = mock.Mock() + client.assume_role.return_value = with_response + return mock.Mock(return_value=client) + + def test_assume_role_with_no_cache(self): + response = { + 'Credentials': { + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz', + 'Expiration': datetime.now(tzlocal()).isoformat() + }, + } + client_creator = self.create_client_creator(with_response=response) + provider = assumerole.AssumeRoleProvider( + self.create_config_loader(), + client_creator, cache={}, profile_name='development') + + credentials = provider.load() + + self.assertEqual(credentials.access_key, 'foo') + self.assertEqual(credentials.secret_key, 'bar') + self.assertEqual(credentials.token, 'baz') + + def test_assume_role_retrieves_from_cache(self): + date_in_future = datetime.utcnow() + timedelta(seconds=1000) + utc_timestamp = date_in_future.isoformat() + 'Z' + self.fake_config['profiles']['development']['role_arn'] = 'myrole' + cache = { + 'development--myrole': { + 'Credentials': { + 'AccessKeyId': 'foo-cached', + 'SecretAccessKey': 'bar-cached', + 'SessionToken': 'baz-cached', + 'Expiration': utc_timestamp, + } + } + } + provider = assumerole.AssumeRoleProvider( + self.create_config_loader(), mock.Mock(), + cache=cache, profile_name='development') + + credentials = provider.load() + + self.assertEqual(credentials.access_key, 'foo-cached') + self.assertEqual(credentials.secret_key, 'bar-cached') + self.assertEqual(credentials.token, 'baz-cached') + + def test_assume_role_in_cache_but_expired(self): + expired_creds = datetime.utcnow() + utc_timestamp = expired_creds.isoformat() + 'Z' + response = { + 'Credentials': { + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz', + 'Expiration': utc_timestamp, + }, + } + client_creator = self.create_client_creator(with_response=response) + cache = { + 'development--myrole': { + 'Credentials': { + 'AccessKeyId': 'foo-cached', + 'SecretAccessKey': 'bar-cached', + 'SessionToken': 'baz-cached', + 'Expiration': utc_timestamp, + } + } + } + provider = assumerole.AssumeRoleProvider( + self.create_config_loader(), client_creator, + cache=cache, profile_name='development') + + credentials = provider.load() + + self.assertEqual(credentials.access_key, 'foo') + self.assertEqual(credentials.secret_key, 'bar') + self.assertEqual(credentials.token, 'baz') + + def test_external_id_provided(self): + self.fake_config['profiles']['development']['external_id'] = 'myid' + response = { + 'Credentials': { + 'AccessKeyId': 'foo', + 'SecretAccessKey': 'bar', + 'SessionToken': 'baz', + 'Expiration': datetime.now(tzlocal()).isoformat(), + }, + } + client_creator = self.create_client_creator(with_response=response) + provider = assumerole.AssumeRoleProvider( + self.create_config_loader(), + client_creator, cache={}, profile_name='development') + + provider.load() + + client = client_creator.return_value + client.assume_role.assert_called_with( + RoleArn='myrole', ExternalId='myid', RoleSessionName=mock.ANY) + + def test_no_config_is_noop(self): + self.fake_config['profiles']['development'] = { + 'aws_access_key_id': 'foo', + 'aws_secret_access_key': 'bar', + } + provider = assumerole.AssumeRoleProvider( + self.create_config_loader(), + mock.Mock(), cache={}, profile_name='development') + + # Because a role_arn was not specified, the AssumeRoleProvider + # is a noop and will not return credentials (which means we + # move on to the next provider). + credentials = provider.load() + self.assertIsNone(credentials) + + def test_source_profile_not_provided(self): + del self.fake_config['profiles']['development']['source_profile'] + provider = assumerole.AssumeRoleProvider( + self.create_config_loader(), + mock.Mock(), cache={}, profile_name='development') + + # source_profile is required, we shoudl get an error. + with self.assertRaises(PartialCredentialsError): + provider.load() + + def test_source_profile_does_not_exist(self): + dev_profile = self.fake_config['profiles']['development'] + dev_profile['source_profile'] = 'does-not-exist' + provider = assumerole.AssumeRoleProvider( + self.create_config_loader(), + mock.Mock(), cache={}, profile_name='development') + + # source_profile is required, we shoudl get an error. + with self.assertRaises(assumerole.InvalidConfigError): + provider.load() + + +class TestJSONCache(unittest.TestCase): + def setUp(self): + self.tempdir = tempfile.mkdtemp() + self.cache = assumerole.JSONFileCache(self.tempdir) + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def test_supports_contains_check(self): + # By default the cache is empty because we're + # using a new temp dir everytime. + self.assertTrue('mykey' not in self.cache) + + def test_add_key_and_contains_check(self): + self.cache['mykey'] = {'foo': 'bar'} + self.assertTrue('mykey' in self.cache) + + def test_added_key_can_be_retrieved(self): + self.cache['mykey'] = {'foo': 'bar'} + self.assertEqual(self.cache['mykey'], {'foo': 'bar'}) + + def test_only_accepts_json_serializable_data(self): + with self.assertRaises(ValueError): + # set()'s cannot be serialized to a JSOn string. + self.cache['mykey'] = set() + + def test_can_override_existing_values(self): + self.cache['mykey'] = {'foo': 'bar'} + self.cache['mykey'] = {'baz': 'newvalue'} + self.assertEqual(self.cache['mykey'], {'baz': 'newvalue'}) + + def test_can_add_multiple_keys(self): + self.cache['mykey'] = {'foo': 'bar'} + self.cache['mykey2'] = {'baz': 'qux'} + self.assertEqual(self.cache['mykey'], {'foo': 'bar'}) + self.assertEqual(self.cache['mykey2'], {'baz': 'qux'}) + + def test_working_dir_does_not_exist(self): + working_dir = os.path.join(self.tempdir, 'foo') + cache = assumerole.JSONFileCache(working_dir) + cache['foo'] = {'bar': 'baz'} + self.assertEqual(cache['foo'], {'bar': 'baz'}) + + def test_key_error_raised_when_cache_key_does_not_exist(self): + with self.assertRaises(KeyError): + self.cache['foo']