Skip to content

Commit

Permalink
{Profile} az account get-access-token: Show expiresOn for managed…
Browse files Browse the repository at this point in the history
… identity (#20219)
  • Loading branch information
jiasli authored Nov 17, 2021
1 parent 9a0d516 commit d7e5ede
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 71 deletions.
50 changes: 24 additions & 26 deletions src/azure-cli-core/azure/cli/core/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,37 +361,42 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No
raise CLIError("Please specify only one of subscription and tenant, not both")

account = self.get_subscription(subscription)
resource = resource or self.cli_ctx.cloud.endpoints.active_directory_resource_id

identity_type, identity_id = Profile._try_parse_msi_account_name(account)
if identity_type:
# MSI
# managed identity
if tenant:
raise CLIError("Tenant shouldn't be specified for MSI account")
msi_creds = MsiAccountTypes.msi_auth_factory(identity_type, identity_id, resource)
msi_creds.set_token()
token_entry = msi_creds.token
creds = (token_entry['token_type'], token_entry['access_token'], token_entry)
raise CLIError("Tenant shouldn't be specified for managed identity account")
from .auth.util import scopes_to_resource
msi_creds = MsiAccountTypes.msi_auth_factory(identity_type, identity_id,
scopes_to_resource(scopes))
sdk_token = msi_creds.get_token(*scopes)
elif in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID):
# Cloud Shell
# Cloud Shell, which is just a system-assigned managed identity.
if tenant:
raise CLIError("Tenant shouldn't be specified for Cloud Shell account")
creds = self._get_token_from_cloud_shell(resource)
from .auth.util import scopes_to_resource
msi_creds = MsiAccountTypes.msi_auth_factory(MsiAccountTypes.system_assigned, identity_id,
scopes_to_resource(scopes))
sdk_token = msi_creds.get_token(*scopes)
else:
credential = self._create_credential(account, tenant)
token = credential.get_token(*scopes)
sdk_token = credential.get_token(*scopes)

import datetime
expiresOn = datetime.datetime.fromtimestamp(token.expires_on).strftime("%Y-%m-%d %H:%M:%S.%f")
# Convert epoch int 'expires_on' to datetime string 'expiresOn' for backward compatibility
# WARNING: expiresOn is deprecated and will be removed in future release.
import datetime
expiresOn = datetime.datetime.fromtimestamp(sdk_token.expires_on).strftime("%Y-%m-%d %H:%M:%S.%f")

token_entry = {
'accessToken': token.token,
'expires_on': token.expires_on,
'expiresOn': expiresOn
}
token_entry = {
'accessToken': sdk_token.token,
'expires_on': sdk_token.expires_on, # epoch int, like 1605238724
'expiresOn': expiresOn # datetime string, like "2020-11-12 13:50:47.114324"
}

# (tokenType, accessToken, tokenEntry)
creds = 'Bearer', sdk_token.token, token_entry

# (tokenType, accessToken, tokenEntry)
creds = 'Bearer', token.token, token_entry
# (cred, subscription, tenant)
return (creds,
None if tenant else str(account[_SUBSCRIPTION_ID]),
Expand Down Expand Up @@ -695,13 +700,6 @@ def get_installation_id(self):
self._storage[_INSTALLATION_ID] = installation_id
return installation_id

def _get_token_from_cloud_shell(self, resource): # pylint: disable=no-self-use
from azure.cli.core.auth.adal_authentication import MSIAuthenticationWrapper
auth = MSIAuthenticationWrapper(resource=resource)
auth.set_token()
token_entry = auth.token
return (token_entry['token_type'], token_entry['access_token'], token_entry)


class MsiAccountTypes:
# pylint: disable=no-method-argument,no-self-argument
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
# --------------------------------------------------------------------------------------------

import requests
from azure.core.credentials import AccessToken
from knack.log import get_logger
from msrestazure.azure_active_directory import MSIAuthentication

from .util import _normalize_scopes, scopes_to_resource
from .util import _normalize_scopes, scopes_to_resource, AccessToken

logger = get_logger(__name__)

Expand Down
11 changes: 8 additions & 3 deletions src/azure-cli-core/azure/cli/core/auth/msal_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from knack.util import CLIError
from msal import PublicClientApplication, ConfidentialClientApplication

from .util import check_result
from .util import check_result, AccessToken

# OAuth 2.0 client credentials flow parameter
# https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow
Expand Down Expand Up @@ -137,9 +137,14 @@ def _build_sdk_access_token(token_entry):
import time
request_time = int(time.time())

# MSAL token entry sample:
# {
# 'access_token': 'eyJ0eXAiOiJKV...',
# 'token_type': 'Bearer',
# 'expires_in': 1618
# }

# Importing azure.core.credentials.AccessToken is expensive.
# This can slow down commands that doesn't need azure.core, like `az account get-access-token`.
# So We define our own AccessToken.
from collections import namedtuple
AccessToken = namedtuple("AccessToken", ["token", "expires_on"])
return AccessToken(token_entry["access_token"], request_time + token_entry["expires_in"])
2 changes: 1 addition & 1 deletion src/azure-cli-core/azure/cli/core/auth/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# pylint: disable=protected-access

import unittest
from ..util import scopes_to_resource, resource_to_scopes, _normalize_scopes, _generate_login_command
from azure.cli.core.auth.util import scopes_to_resource, resource_to_scopes, _normalize_scopes, _generate_login_command


class TestUtil(unittest.TestCase):
Expand Down
5 changes: 5 additions & 0 deletions src/azure-cli-core/azure/cli/core/auth/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@
# --------------------------------------------------------------------------------------------

import os
from collections import namedtuple

from knack.log import get_logger

logger = get_logger(__name__)


AccessToken = namedtuple("AccessToken", ["token", "expires_on"])


def aad_error_handler(error, **kwargs):
""" Handle the error from AAD server returned by ADAL or MSAL. """

Expand Down
Loading

0 comments on commit d7e5ede

Please sign in to comment.