Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for sso-session credentials #2812

Merged
merged 1 commit into from
Nov 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changes/next-release/enhancement-sso-60997.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"type": "enhancement",
"category": "sso",
"description": "Add support for loading sso-session profiles from the aws config"
}
162 changes: 77 additions & 85 deletions botocore/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from collections import namedtuple
from copy import deepcopy
from hashlib import sha1
from pathlib import Path

from dateutil.parser import parse
from dateutil.tz import tzlocal, tzutc
Expand All @@ -43,10 +42,12 @@
UnauthorizedSSOTokenError,
UnknownCredentialError,
)
from botocore.tokens import SSOTokenProvider
from botocore.utils import (
ContainerMetadataFetcher,
FileWebIdentityTokenLoader,
InstanceMetadataFetcher,
JSONFileCache,
SSOTokenLoader,
parse_key_val_file,
resolve_imds_endpoint_mode,
Expand Down Expand Up @@ -223,6 +224,7 @@ def _create_sso_provider(self, profile_name):
profile_name=profile_name,
cache=self._cache,
token_cache=self._sso_token_cache,
token_provider=SSOTokenProvider(self._session),
)


Expand Down Expand Up @@ -292,68 +294,6 @@ def __call__(self):
return _Refresher(actual_refresh)


class JSONFileCache:
"""JSON file cache.
This provides a dict like interface that stores JSON serializable
objects.
The objects are serialized to JSON and stored in a file. These
values can be retrieved at a later time.
"""

CACHE_DIR = os.path.expanduser(os.path.join('~', '.aws', 'boto', 'cache'))

def __init__(self, working_dir=CACHE_DIR, dumps_func=None):
self._working_dir = working_dir
if dumps_func is None:
dumps_func = self._default_dumps
self._dumps = dumps_func

def _default_dumps(self, obj):
return json.dumps(obj, default=_serialize_if_needed)

def __contains__(self, cache_key):
actual_key = self._convert_cache_key(cache_key)
return os.path.isfile(actual_key)

def __getitem__(self, cache_key):
"""Retrieve value from a cache key."""
actual_key = self._convert_cache_key(cache_key)
try:
with open(actual_key) as f:
return json.load(f)
except (OSError, ValueError):
raise KeyError(cache_key)

def __delitem__(self, cache_key):
actual_key = self._convert_cache_key(cache_key)
try:
key_path = Path(actual_key)
key_path.unlink()
except FileNotFoundError:
raise KeyError(cache_key)

def __setitem__(self, cache_key, value):
full_key = self._convert_cache_key(cache_key)
try:
file_content = self._dumps(value)
except (TypeError, ValueError):
raise ValueError(
f"Value cannot be cached, must be "
f"JSON serializable: {value}"
)
if not os.path.isdir(self._working_dir):
os.makedirs(self._working_dir)
with os.fdopen(
os.open(full_key, os.O_WRONLY | os.O_CREAT, 0o600), 'w'
) as f:
f.truncate()
f.write(file_content)

def _convert_cache_key(self, cache_key):
full_path = os.path.join(self._working_dir, cache_key + '.json')
return full_path


class Credentials:
"""
Holds the credentials needed to authenticate requests.
Expand Down Expand Up @@ -2118,13 +2058,17 @@ def __init__(
token_loader=None,
cache=None,
expiry_window_seconds=None,
token_provider=None,
sso_session_name=None,
):
self._client_creator = client_creator
self._sso_region = sso_region
self._role_name = role_name
self._account_id = account_id
self._start_url = start_url
self._token_loader = token_loader
self._token_provider = token_provider
self._sso_session_name = sso_session_name
super().__init__(cache, expiry_window_seconds)

def _create_cache_key(self):
Expand All @@ -2133,10 +2077,13 @@ def _create_cache_key(self):
The cache key is intended to be compatible with file names.
"""
args = {
'startUrl': self._start_url,
'roleName': self._role_name,
'accountId': self._account_id,
}
if self._sso_session_name:
args['sessionName'] = self._sso_session_name
else:
args['startUrl'] = self._start_url
# NOTE: It would be good to hoist this cache key construction logic
# into the CachedCredentialFetcher class as we should be consistent.
# Unfortunately, the current assume role fetchers that sub class don't
Expand All @@ -2159,12 +2106,16 @@ def _get_credentials(self):
region_name=self._sso_region,
)
client = self._client_creator('sso', config=config)
if self._token_provider:
initial_token_data = self._token_provider.load_token()
token = initial_token_data.get_frozen_token().token
else:
token = self._token_loader(self._start_url)['accessToken']

token_dict = self._token_loader(self._start_url)
kwargs = {
'roleName': self._role_name,
'accountId': self._account_id,
'accessToken': token_dict['accessToken'],
'accessToken': token,
}
try:
response = client.get_role_credentials(**kwargs)
Expand All @@ -2190,12 +2141,17 @@ class SSOProvider(CredentialProvider):
_SSO_TOKEN_CACHE_DIR = os.path.expanduser(
os.path.join('~', '.aws', 'sso', 'cache')
)
_SSO_CONFIG_VARS = [
'sso_start_url',
'sso_region',
_PROFILE_REQUIRED_CONFIG_VARS = (
'sso_role_name',
'sso_account_id',
]
)
_SSO_REQUIRED_CONFIG_VARS = (
'sso_start_url',
'sso_region',
)
_ALL_REQUIRED_CONFIG_VARS = (
_PROFILE_REQUIRED_CONFIG_VARS + _SSO_REQUIRED_CONFIG_VARS
)

def __init__(
self,
Expand All @@ -2204,10 +2160,12 @@ def __init__(
profile_name,
cache=None,
token_cache=None,
token_provider=None,
):
if token_cache is None:
token_cache = JSONFileCache(self._SSO_TOKEN_CACHE_DIR)
self._token_cache = token_cache
self._token_provider = token_provider
if cache is None:
cache = {}
self.cache = cache
Expand All @@ -2220,17 +2178,24 @@ def _load_sso_config(self):
profiles = loaded_config.get('profiles', {})
profile_name = self._profile_name
profile_config = profiles.get(self._profile_name, {})
sso_sessions = loaded_config.get('sso_sessions', {})

# Role name & Account ID indicate the cred provider should be used
sso_cred_vars = ('sso_role_name', 'sso_account_id')
if all(c not in profile_config for c in sso_cred_vars):
if all(
c not in profile_config for c in self._PROFILE_REQUIRED_CONFIG_VARS
):
return None

resolved_config, extra_reqs = self._resolve_sso_session_reference(
profile_config, sso_sessions
)

config = {}
missing_config_vars = []
for config_var in self._SSO_CONFIG_VARS:
if config_var in profile_config:
config[config_var] = profile_config[config_var]
all_required_configs = self._ALL_REQUIRED_CONFIG_VARS + extra_reqs
for config_var in all_required_configs:
if config_var in resolved_config:
config[config_var] = resolved_config[config_var]
else:
missing_config_vars.append(config_var)

Expand All @@ -2242,23 +2207,50 @@ def _load_sso_config(self):
'required configuration: %s' % (profile_name, missing)
)
)

return config

def _resolve_sso_session_reference(self, profile_config, sso_sessions):
sso_session_name = profile_config.get('sso_session')
if sso_session_name is None:
# No reference to resolve, proceed with legacy flow
return profile_config, ()

if sso_session_name not in sso_sessions:
error_msg = f'The specified sso-session does not exist: "{sso_session_name}"'
raise InvalidConfigError(error_msg=error_msg)

config = profile_config.copy()
session = sso_sessions[sso_session_name]
for config_var, val in session.items():
# Validate any keys referenced in both profile and sso_session match
if config.get(config_var, val) != val:
error_msg = (
f"The value for {config_var} is inconsistent between "
f"profile ({config[config_var]}) and sso-session ({val})."
)
raise InvalidConfigError(error_msg=error_msg)
config[config_var] = val
return config, ('sso_session',)

def load(self):
sso_config = self._load_sso_config()
if not sso_config:
return None

sso_fetcher = SSOCredentialFetcher(
sso_config['sso_start_url'],
sso_config['sso_region'],
sso_config['sso_role_name'],
sso_config['sso_account_id'],
self._client_creator,
token_loader=SSOTokenLoader(cache=self._token_cache),
cache=self.cache,
)
fetcher_kwargs = {
'start_url': sso_config['sso_start_url'],
'sso_region': sso_config['sso_region'],
'role_name': sso_config['sso_role_name'],
'account_id': sso_config['sso_account_id'],
'client_creator': self._client_creator,
'token_loader': SSOTokenLoader(cache=self._token_cache),
'cache': self.cache,
}
if 'sso_session' in sso_config:
fetcher_kwargs['sso_session_name'] = sso_config['sso_session']
fetcher_kwargs['token_provider'] = self._token_provider

sso_fetcher = SSOCredentialFetcher(**fetcher_kwargs)

return DeferredRefreshableCredentials(
method=self.METHOD,
Expand Down
6 changes: 3 additions & 3 deletions botocore/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@
from botocore import UNSIGNED
from botocore.compat import total_seconds
from botocore.config import Config
from botocore.credentials import JSONFileCache
from botocore.exceptions import (
ClientError,
InvalidConfigError,
TokenRetrievalError,
)
from botocore.utils import CachedProperty, SSOTokenLoader
from botocore.utils import CachedProperty, JSONFileCache, SSOTokenLoader

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -182,11 +181,12 @@ class SSOTokenProvider:
"sso_region",
]
_GRANT_TYPE = "refresh_token"
DEFAULT_CACHE_CLS = JSONFileCache

def __init__(self, session, cache=None, time_fetcher=_utc_now):
self._session = session
if cache is None:
cache = JSONFileCache(
cache = self.DEFAULT_CACHE_CLS(
self._SSO_TOKEN_CACHE_DIR,
dumps_func=_sso_json_dumps,
)
Expand Down
70 changes: 70 additions & 0 deletions botocore/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import time
import warnings
import weakref
from pathlib import Path

import dateutil.parser
from dateutil.tz import tzutc
Expand Down Expand Up @@ -3217,3 +3218,72 @@ def is_s3_accelerate_url(url):

# Remaining parts must all be in the whitelist.
return all(p in S3_ACCELERATE_WHITELIST for p in feature_parts)


class JSONFileCache:
"""JSON file cache.
This provides a dict like interface that stores JSON serializable
objects.
The objects are serialized to JSON and stored in a file. These
values can be retrieved at a later time.
"""

CACHE_DIR = os.path.expanduser(os.path.join('~', '.aws', 'boto', 'cache'))

def __init__(self, working_dir=CACHE_DIR, dumps_func=None):
self._working_dir = working_dir
if dumps_func is None:
dumps_func = self._default_dumps
self._dumps = dumps_func

def _default_dumps(self, obj):
return json.dumps(obj, default=self._serialize_if_needed)

def __contains__(self, cache_key):
actual_key = self._convert_cache_key(cache_key)
return os.path.isfile(actual_key)

def __getitem__(self, cache_key):
"""Retrieve value from a cache key."""
actual_key = self._convert_cache_key(cache_key)
try:
with open(actual_key) as f:
return json.load(f)
except (OSError, ValueError):
raise KeyError(cache_key)

def __delitem__(self, cache_key):
actual_key = self._convert_cache_key(cache_key)
try:
key_path = Path(actual_key)
key_path.unlink()
except FileNotFoundError:
raise KeyError(cache_key)

def __setitem__(self, cache_key, value):
full_key = self._convert_cache_key(cache_key)
try:
file_content = self._dumps(value)
except (TypeError, ValueError):
raise ValueError(
f"Value cannot be cached, must be "
f"JSON serializable: {value}"
)
if not os.path.isdir(self._working_dir):
os.makedirs(self._working_dir)
with os.fdopen(
os.open(full_key, os.O_WRONLY | os.O_CREAT, 0o600), 'w'
) as f:
f.truncate()
f.write(file_content)

def _convert_cache_key(self, cache_key):
full_path = os.path.join(self._working_dir, cache_key + '.json')
return full_path

def _serialize_if_needed(self, value, iso=False):
if isinstance(value, datetime.datetime):
if iso:
return value.isoformat()
return value.strftime('%Y-%m-%dT%H:%M:%S%Z')
return value
Loading