Skip to content

Commit

Permalink
[MSAL]Add AAD error handling, refine SP logout (#13877)
Browse files Browse the repository at this point in the history
* [MSAL]Add AAD error handling, refine SP logout

* [MSAL]support MSI object/resource id

* [MSAL]support MSI object/resource id

* [MSAL]support MSI object/resource id

* [MSAL]support MSI object/resource id

* [MSAL]support MSI object/resource id
  • Loading branch information
qianwens authored Jun 10, 2020
1 parent 8883307 commit 6e46db4
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 40 deletions.
40 changes: 25 additions & 15 deletions src/azure-cli-core/azure/cli/core/_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,10 @@ def _msal_app(self):
# Initialize _msal_app for logout, since Azure Identity doesn't provide the functionality for logout
from msal import PublicClientApplication
# sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py:95
from azure.identity._internal.persistent_cache import load_persistent_cache
from azure.identity._internal.persistent_cache import load_user_cache

# Store for user token persistence
cache = load_persistent_cache(self.allow_unencrypted)
cache = load_user_cache(self.allow_unencrypted)
# Build the authority in MSAL style
msal_authority = "https://{}/{}".format(self.authority, self.tenant_id)
return PublicClientApplication(authority=msal_authority, client_id=self.client_id, token_cache=cache)
Expand Down Expand Up @@ -168,38 +168,46 @@ def login_with_service_principal_certificate(self, client_id, certificate_path):
def login_with_managed_identity(self, resource, identity_id=None):
from msrestazure.tools import is_valid_resource_id
from requests import HTTPError
from azure.core.exceptions import ClientAuthenticationError

credential = None
id_type = None
scope = resource.rstrip('/') + '/.default'
if identity_id:
# Try resource ID
if is_valid_resource_id(identity_id):
# TODO: Support resource ID in Azure Identity
credential = ManagedIdentityCredential(resource_id=identity_id)
credential = ManagedIdentityCredential(identity_config={"msi_res_id": identity_id})
id_type = self.MANAGED_IDENTITY_RESOURCE_ID
else:
authenticated = False
try:
# Try client ID
credential = ManagedIdentityCredential(client_id=identity_id)
credential.get_token(scope)
id_type = self.MANAGED_IDENTITY_CLIENT_ID
authenticated = True
except ClientAuthenticationError as e:
logger.debug('Managed Identity authentication error: %s', e.message)
logger.info('Username is not an MSI client id')
except HTTPError as ex:
if ex.response.reason == 'Bad Request' and ex.response.status == 400:
logger.info('Sniff: not an MSI client id')
logger.info('Username is not an MSI client id')
else:
raise

if not authenticated:
try:
# Try object ID
# TODO: Support resource ID in Azure Identity
credential = ManagedIdentityCredential(object_id=identity_id)
credential = ManagedIdentityCredential(identity_config={"object_id": identity_id})
credential.get_token(scope)
id_type = self.MANAGED_IDENTITY_OBJECT_ID
authenticated = True
except ClientAuthenticationError as e:
logger.debug('Managed Identity authentication error: %s', e.message)
logger.info('Username is not an MSI object id')
except HTTPError as ex:
if ex.response.reason == 'Bad Request' and ex.response.status == 400:
logger.info('Sniff: not an MSI object id')
logger.info('Username is not an MSI object id')
else:
raise

Expand Down Expand Up @@ -258,24 +266,26 @@ def _decode_managed_identity_token(credential, resource):
decoded = json.loads(decoded_str)
return decoded

def get_user(self, user_or_sp=None):
accounts = self._msal_app.get_accounts(user_or_sp) if user_or_sp else self._msal_app.get_accounts()
def get_user(self, user=None):
accounts = self._msal_app.get_accounts(user) if user else self._msal_app.get_accounts()
return accounts

def logout_user(self, user_or_sp):
accounts = self._msal_app.get_accounts(user_or_sp)
def logout_user(self, user):
accounts = self._msal_app.get_accounts(user)
logger.info('Before account removal:')
logger.info(json.dumps(accounts))

# `accounts` are the same user in all tenants, log out all of them
for account in accounts:
self._msal_app.remove_account(account)

accounts = self._msal_app.get_accounts(user_or_sp)
accounts = self._msal_app.get_accounts(user)
logger.info('After account removal:')
logger.info(json.dumps(accounts))

def logout_sp(self, sp):
# remove service principal secrets
self._msal_store.remove_cached_creds(user_or_sp)
self._msal_store.remove_cached_creds(sp)

def logout_all(self):
# TODO: Support multi-authority logout
Expand Down Expand Up @@ -423,7 +433,7 @@ def add_credential(self, credential):
"refreshToken": refresh_token[0]['secret'],
"_clientId": _CLIENT_ID,
"_authority": self._cli_ctx.cloud.endpoints.active_directory.rstrip('/') +
"/" + credential._auth_record.tenant_id,
"/" + credential._auth_record.tenant_id, # pylint: disable=bad-continuation
"isMRRT": True
}
self.adal_token_cache.add([entry])
Expand Down
2 changes: 2 additions & 0 deletions src/azure-cli-core/azure/cli/core/_msal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ class AdalRefreshTokenBasedClientApplication(ClientApplication):
"""
This is added only for vmssh feature.
It is a temporary solution and will deprecate after MSAL adopted completely.
todo: msal
"""
def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
self, authority, scopes, account, **kwargs):
# pylint: disable=line-too-long
return self._acquire_token_silent_by_finding_specific_refresh_token(
authority, scopes, None, **kwargs)

# pylint:disable=arguments-differ
def _acquire_token_silent_by_finding_specific_refresh_token(
self, authority, scopes, query,
rt_remover=None, break_condition=lambda response: False, **kwargs):
Expand Down
13 changes: 6 additions & 7 deletions src/azure-cli-core/azure/cli/core/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def _get_authority_url(cli_ctx, tenant):


def get_credential_types(cli_ctx):

class CredentialType(Enum): # pylint: disable=too-few-public-methods
cloud = get_active_cloud(cli_ctx)
management = cli_ctx.cloud.endpoints.management
Expand Down Expand Up @@ -110,6 +109,7 @@ def __init__(self, storage=None, auth_ctx_factory=None, use_global_creds_cache=T
self._ad = self.cli_ctx.cloud.endpoints.active_directory
self._adal_cache = ADALCredentialCache(cli_ctx=self.cli_ctx)

# pylint: disable=too-many-branches,too-many-statements
def login(self,
interactive,
username,
Expand Down Expand Up @@ -428,9 +428,6 @@ def logout(self, user_or_sp, clear_credential):
# Always remove credential from the legacy cred cache, regardless of MSAL cache, to be deprecated
adal_cache = ADALCredentialCache(cli_ctx=self.cli_ctx)
adal_cache.remove_cached_creds(user_or_sp)
# remove service principle secret
msal_cache = MSALSecretStore()
msal_cache.remove_cached_creds(user_or_sp)

logger.warning('Account %s was logged out from Azure CLI', user_or_sp)
else:
Expand All @@ -453,7 +450,8 @@ def logout(self, user_or_sp, clear_credential):
'To clear the credential, run `az logout --username %s --clear-credential`.',
user_or_sp, user_or_sp)
else:
logger.warning("The credential of %s was not found from MSAL encrypted cache.", user_or_sp)
# remove service principle secret
identity.logout_sp(user_or_sp)

def logout_all(self, clear_credential):
self._storage[_SUBSCRIPTIONS] = []
Expand Down Expand Up @@ -637,8 +635,9 @@ def refresh_accounts(self, subscription_finder=None):
subscriptions = subscription_finder.find_using_specific_tenant(tenant, identity_credential)
else:
# pylint: disable=protected-access
subscriptions = subscription_finder.find_using_common_tenant(identity_credential._auth_record,
identity_credential)
subscriptions = subscription_finder. \
find_using_common_tenant(identity_credential._auth_record, # pylint: disable=protected-access
identity_credential)
except Exception as ex: # pylint: disable=broad-except
logger.warning("Refreshing for '%s' failed with an error '%s'. The existing accounts were not "
"modified. You can run 'az login' later to explicitly refresh them", user_name, ex)
Expand Down
31 changes: 13 additions & 18 deletions src/azure-cli-core/azure/cli/core/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,24 +53,19 @@ def _get_token(self, *scopes):
if in_cloud_console():
AuthenticationWrapper._log_hostname()

raise CLIError("Credentials have expired due to inactivity or "
"configuration of your account was changed. {}Error details: {}"
.format("Please run 'az login'. " if not in_cloud_console() else '', err))
# todo: error type
# err = (getattr(err, 'error_response', None) or {}).get('error_description') or ''
# if 'AADSTS70008' in err: # all errors starting with 70008 should be creds expiration related
# raise CLIError("Credentials have expired due to inactivity. {}".format(
# "Please run 'az login'" if not in_cloud_console() else ''))
# if 'AADSTS50079' in err:
# raise CLIError("Configuration of your account was changed. {}".format(
# "Please run 'az login'" if not in_cloud_console() else ''))
# if 'AADSTS50173' in err:
# raise CLIError("The credential data used by CLI has been expired because you might have changed or "
# "reset the password. {}".format(
# "Please clear browser's cookies and run 'az login'"
# if not in_cloud_console() else ''))
#
# raise CLIError(err)
err = getattr(err, 'message', None) or ''
if 'AADSTS70008' in err: # all errors starting with 70008 should be creds expiration related
raise CLIError("Credentials have expired due to inactivity. {}".format(
"Please run 'az login'" if not in_cloud_console() else ''))
if 'AADSTS50079' in err:
raise CLIError("Configuration of your account was changed. {}".format(
"Please run 'az login'" if not in_cloud_console() else ''))
if 'AADSTS50173' in err:
raise CLIError("The credential data used by CLI has been expired because you might have changed or "
"reset the password. {}".format(
"Please clear browser's cookies and run 'az login'"
if not in_cloud_console() else ''))
raise CLIError(err)
except requests.exceptions.SSLError as err:
from .util import SSLERROR_TEMPLATE
raise CLIError(SSLERROR_TEMPLATE.format(str(err)))
Expand Down

0 comments on commit 6e46db4

Please sign in to comment.