Skip to content

Commit

Permalink
Port SSO credential provider updates
Browse files Browse the repository at this point in the history
The updates allow using sso-session configuration as part of
the SSO credential provider. This was ported from this PR:
boto/botocore#2812
  • Loading branch information
kyleknap committed Nov 16, 2022
1 parent 8fef0ef commit 7ba7b5f
Show file tree
Hide file tree
Showing 5 changed files with 306 additions and 78 deletions.
5 changes: 5 additions & 0 deletions .changes/next-release/enhancement-sso-81091.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"type": "enhancement",
"category": "sso",
"description": "Add support for loading sso-session profiles for SSO credential provider"
}
152 changes: 77 additions & 75 deletions awscli/botocore/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@
UnauthorizedSSOTokenError,
UnknownCredentialError,
)
from botocore.tokens import SSOTokenProvider
from botocore.utils import (
ContainerMetadataFetcher,
FileWebIdentityTokenLoader,
InstanceMetadataFetcher,
JSONFileCache,
SSOTokenLoader,
original_ld_library_path,
parse_key_val_file,
Expand Down Expand Up @@ -211,6 +213,7 @@ def _create_sso_provider(self, profile_name):
profile_name=profile_name,
cache=self._cache,
token_cache=self._sso_token_cache,
token_provider=SSOTokenProvider(self._session),
)


Expand Down Expand Up @@ -282,57 +285,6 @@ def __call__(self):
return _Refresher(actual_refresh)


class JSONFileCache(object):
"""JSON file cache.
This provides a dict like interface that stores JSON serializable
objects.
The objects are serialized to JSON and stored in a file. These
values can be retrieved at a later time.
"""

CACHE_DIR = os.path.expanduser(os.path.join('~', '.aws', 'boto', 'cache'))

def __init__(self, working_dir=CACHE_DIR, dumps_func=None):
self._working_dir = working_dir
if dumps_func is None:
dumps_func = self._default_dumps
self._dumps = dumps_func

def _default_dumps(self, obj):
return json.dumps(obj, default=_serialize_if_needed)

def __contains__(self, cache_key):
actual_key = self._convert_cache_key(cache_key)
return os.path.isfile(actual_key)

def __getitem__(self, cache_key):
"""Retrieve value from a cache key."""
actual_key = self._convert_cache_key(cache_key)
try:
with open(actual_key) as f:
return json.load(f)
except (OSError, ValueError, IOError):
raise KeyError(cache_key)

def __setitem__(self, cache_key, value):
full_key = self._convert_cache_key(cache_key)
try:
file_content = self._dumps(value)
except (TypeError, ValueError):
raise ValueError("Value cannot be cached, must be "
"JSON serializable: %s" % value)
if not os.path.isdir(self._working_dir):
os.makedirs(self._working_dir)
with os.fdopen(os.open(full_key,
os.O_WRONLY | os.O_CREAT, 0o600), 'w') as f:
f.truncate()
f.write(file_content)

def _convert_cache_key(self, cache_key):
full_path = os.path.join(self._working_dir, cache_key + '.json')
return full_path


class Credentials(object):
"""
Holds the credentials needed to authenticate requests.
Expand Down Expand Up @@ -2002,13 +1954,16 @@ class SSOCredentialFetcher(CachedCredentialFetcher):

def __init__(self, start_url, sso_region, role_name, account_id,
client_creator, token_loader=None, cache=None,
expiry_window_seconds=None):
expiry_window_seconds=None, token_provider=None,
sso_session_name=None):
self._client_creator = client_creator
self._sso_region = sso_region
self._role_name = role_name
self._account_id = account_id
self._start_url = start_url
self._token_loader = token_loader
self._token_provider = token_provider
self._sso_session_name = sso_session_name
super(SSOCredentialFetcher, self).__init__(
cache, expiry_window_seconds
)
Expand All @@ -2019,10 +1974,13 @@ def _create_cache_key(self):
The cache key is intended to be compatible with file names.
"""
args = {
'startUrl': self._start_url,
'roleName': self._role_name,
'accountId': self._account_id,
}
if self._sso_session_name:
args['sessionName'] = self._sso_session_name
else:
args['startUrl'] = self._start_url
# NOTE: It would be good to hoist this cache key construction logic
# into the CachedCredentialFetcher class as we should be consistent.
# Unfortunately, the current assume role fetchers that sub class don't
Expand All @@ -2045,12 +2003,16 @@ def _get_credentials(self):
region_name=self._sso_region,
)
client = self._client_creator('sso', config=config)
if self._token_provider:
initial_token_data = self._token_provider.load_token()
token = initial_token_data.get_frozen_token().token
else:
token = self._token_loader(self._start_url)['accessToken']

token_dict = self._token_loader(self._start_url)
kwargs = {
'roleName': self._role_name,
'accountId': self._account_id,
'accessToken': token_dict['accessToken'],
'accessToken': token,
}
try:
response = client.get_role_credentials(**kwargs)
Expand All @@ -2076,18 +2038,24 @@ class SSOProvider(CredentialProvider):
_SSO_TOKEN_CACHE_DIR = os.path.expanduser(
os.path.join('~', '.aws', 'sso', 'cache')
)
_SSO_CONFIG_VARS = [
'sso_start_url',
'sso_region',
_PROFILE_REQUIRED_CONFIG_VARS = (
'sso_role_name',
'sso_account_id',
]
)
_SSO_REQUIRED_CONFIG_VARS = (
'sso_start_url',
'sso_region',
)
_ALL_REQUIRED_CONFIG_VARS = (
_PROFILE_REQUIRED_CONFIG_VARS + _SSO_REQUIRED_CONFIG_VARS
)

def __init__(self, load_config, client_creator, profile_name,
cache=None, token_cache=None):
cache=None, token_cache=None, token_provider=None):
if token_cache is None:
token_cache = JSONFileCache(self._SSO_TOKEN_CACHE_DIR)
self._token_cache = token_cache
self._token_provider = token_provider
if cache is None:
cache = {}
self.cache = cache
Expand All @@ -2100,17 +2068,24 @@ def _load_sso_config(self):
profiles = loaded_config.get('profiles', {})
profile_name = self._profile_name
profile_config = profiles.get(self._profile_name, {})
sso_sessions = loaded_config.get('sso_sessions', {})

# Role name & Account ID indicate the cred provider should be used
sso_cred_vars = ('sso_role_name', 'sso_account_id')
if all(c not in profile_config for c in sso_cred_vars):
if all(
c not in profile_config for c in self._PROFILE_REQUIRED_CONFIG_VARS
):
return None

resolved_config, extra_reqs = self._resolve_sso_session_reference(
profile_config, sso_sessions
)

config = {}
missing_config_vars = []
for config_var in self._SSO_CONFIG_VARS:
if config_var in profile_config:
config[config_var] = profile_config[config_var]
all_required_configs = self._ALL_REQUIRED_CONFIG_VARS + extra_reqs
for config_var in all_required_configs:
if config_var in resolved_config:
config[config_var] = resolved_config[config_var]
else:
missing_config_vars.append(config_var)

Expand All @@ -2122,23 +2097,50 @@ def _load_sso_config(self):
'required configuration: %s' % (profile_name, missing)
)
)

return config

def _resolve_sso_session_reference(self, profile_config, sso_sessions):
sso_session_name = profile_config.get('sso_session')
if sso_session_name is None:
# No reference to resolve, proceed with legacy flow
return profile_config, ()

if sso_session_name not in sso_sessions:
error_msg = f'The specified sso-session does not exist: "{sso_session_name}"'
raise InvalidConfigError(error_msg=error_msg)

config = profile_config.copy()
session = sso_sessions[sso_session_name]
for config_var, val in session.items():
# Validate any keys referenced in both profile and sso_session match
if config.get(config_var, val) != val:
error_msg = (
f"The value for {config_var} is inconsistent between "
f"profile ({config[config_var]}) and sso-session ({val})."
)
raise InvalidConfigError(error_msg=error_msg)
config[config_var] = val
return config, ('sso_session',)

def load(self):
sso_config = self._load_sso_config()
if not sso_config:
return None

sso_fetcher = SSOCredentialFetcher(
sso_config['sso_start_url'],
sso_config['sso_region'],
sso_config['sso_role_name'],
sso_config['sso_account_id'],
self._client_creator,
token_loader=SSOTokenLoader(cache=self._token_cache),
cache=self.cache,
)
fetcher_kwargs = {
'start_url': sso_config['sso_start_url'],
'sso_region': sso_config['sso_region'],
'role_name': sso_config['sso_role_name'],
'account_id': sso_config['sso_account_id'],
'client_creator': self._client_creator,
'token_loader': SSOTokenLoader(cache=self._token_cache),
'cache': self.cache,
}
if 'sso_session' in sso_config:
fetcher_kwargs['sso_session_name'] = sso_config['sso_session']
fetcher_kwargs['token_provider'] = self._token_provider

sso_fetcher = SSOCredentialFetcher(**fetcher_kwargs)

return DeferredRefreshableCredentials(
method=self.METHOD,
Expand Down
6 changes: 3 additions & 3 deletions awscli/botocore/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,12 @@
from botocore import UNSIGNED
from botocore.compat import total_seconds
from botocore.config import Config
from botocore.credentials import JSONFileCache
from botocore.exceptions import (
ClientError,
InvalidConfigError,
TokenRetrievalError,
)
from botocore.utils import CachedProperty, SSOTokenLoader
from botocore.utils import CachedProperty, JSONFileCache, SSOTokenLoader

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -184,11 +183,12 @@ class SSOTokenProvider:
"sso_region",
]
_GRANT_TYPE = "refresh_token"
DEFAULT_CACHE_CLS = JSONFileCache

def __init__(self, session, cache=None, time_fetcher=_utc_now):
self._session = session
if cache is None:
cache = JSONFileCache(
cache = self.DEFAULT_CACHE_CLS(
self._SSO_TOKEN_CACHE_DIR,
dumps_func=_sso_json_dumps,
)
Expand Down
70 changes: 70 additions & 0 deletions awscli/botocore/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import time
import warnings
import weakref
from pathlib import Path

import botocore
import botocore.awsrequest
Expand Down Expand Up @@ -2887,3 +2888,72 @@ def _get_global_endpoint(self, endpoint, endpoint_variant_tags=None):
dns_suffix = self._DEFAULT_DNS_SUFFIX

return f"https://{endpoint}.endpoint.events.{dns_suffix}/"


class JSONFileCache:
"""JSON file cache.
This provides a dict like interface that stores JSON serializable
objects.
The objects are serialized to JSON and stored in a file. These
values can be retrieved at a later time.
"""

CACHE_DIR = os.path.expanduser(os.path.join('~', '.aws', 'boto', 'cache'))

def __init__(self, working_dir=CACHE_DIR, dumps_func=None):
self._working_dir = working_dir
if dumps_func is None:
dumps_func = self._default_dumps
self._dumps = dumps_func

def _default_dumps(self, obj):
return json.dumps(obj, default=self._serialize_if_needed)

def __contains__(self, cache_key):
actual_key = self._convert_cache_key(cache_key)
return os.path.isfile(actual_key)

def __getitem__(self, cache_key):
"""Retrieve value from a cache key."""
actual_key = self._convert_cache_key(cache_key)
try:
with open(actual_key) as f:
return json.load(f)
except (OSError, ValueError):
raise KeyError(cache_key)

def __delitem__(self, cache_key):
actual_key = self._convert_cache_key(cache_key)
try:
key_path = Path(actual_key)
key_path.unlink()
except FileNotFoundError:
raise KeyError(cache_key)

def __setitem__(self, cache_key, value):
full_key = self._convert_cache_key(cache_key)
try:
file_content = self._dumps(value)
except (TypeError, ValueError):
raise ValueError(
f"Value cannot be cached, must be "
f"JSON serializable: {value}"
)
if not os.path.isdir(self._working_dir):
os.makedirs(self._working_dir)
with os.fdopen(
os.open(full_key, os.O_WRONLY | os.O_CREAT, 0o600), 'w'
) as f:
f.truncate()
f.write(file_content)

def _convert_cache_key(self, cache_key):
full_path = os.path.join(self._working_dir, cache_key + '.json')
return full_path

def _serialize_if_needed(self, value, iso=False):
if isinstance(value, datetime.datetime):
if iso:
return value.isoformat()
return value.strftime('%Y-%m-%dT%H:%M:%S%Z')
return value
Loading

0 comments on commit 7ba7b5f

Please sign in to comment.