diff --git a/src/azure-cli-core/azure/cli/core/_profile.py b/src/azure-cli-core/azure/cli/core/_profile.py index d6330e8874e..28a7ba493e2 100644 --- a/src/azure-cli-core/azure/cli/core/_profile.py +++ b/src/azure-cli-core/azure/cli/core/_profile.py @@ -361,37 +361,42 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No raise CLIError("Please specify only one of subscription and tenant, not both") account = self.get_subscription(subscription) - resource = resource or self.cli_ctx.cloud.endpoints.active_directory_resource_id identity_type, identity_id = Profile._try_parse_msi_account_name(account) if identity_type: - # MSI + # managed identity if tenant: - raise CLIError("Tenant shouldn't be specified for MSI account") - msi_creds = MsiAccountTypes.msi_auth_factory(identity_type, identity_id, resource) - msi_creds.set_token() - token_entry = msi_creds.token - creds = (token_entry['token_type'], token_entry['access_token'], token_entry) + raise CLIError("Tenant shouldn't be specified for managed identity account") + from .auth.util import scopes_to_resource + msi_creds = MsiAccountTypes.msi_auth_factory(identity_type, identity_id, + scopes_to_resource(scopes)) + sdk_token = msi_creds.get_token(*scopes) elif in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID): - # Cloud Shell + # Cloud Shell, which is just a system-assigned managed identity. if tenant: raise CLIError("Tenant shouldn't be specified for Cloud Shell account") - creds = self._get_token_from_cloud_shell(resource) + from .auth.util import scopes_to_resource + msi_creds = MsiAccountTypes.msi_auth_factory(MsiAccountTypes.system_assigned, identity_id, + scopes_to_resource(scopes)) + sdk_token = msi_creds.get_token(*scopes) else: credential = self._create_credential(account, tenant) - token = credential.get_token(*scopes) + sdk_token = credential.get_token(*scopes) - import datetime - expiresOn = datetime.datetime.fromtimestamp(token.expires_on).strftime("%Y-%m-%d %H:%M:%S.%f") + # Convert epoch int 'expires_on' to datetime string 'expiresOn' for backward compatibility + # WARNING: expiresOn is deprecated and will be removed in future release. + import datetime + expiresOn = datetime.datetime.fromtimestamp(sdk_token.expires_on).strftime("%Y-%m-%d %H:%M:%S.%f") - token_entry = { - 'accessToken': token.token, - 'expires_on': token.expires_on, - 'expiresOn': expiresOn - } + token_entry = { + 'accessToken': sdk_token.token, + 'expires_on': sdk_token.expires_on, # epoch int, like 1605238724 + 'expiresOn': expiresOn # datetime string, like "2020-11-12 13:50:47.114324" + } + + # (tokenType, accessToken, tokenEntry) + creds = 'Bearer', sdk_token.token, token_entry - # (tokenType, accessToken, tokenEntry) - creds = 'Bearer', token.token, token_entry # (cred, subscription, tenant) return (creds, None if tenant else str(account[_SUBSCRIPTION_ID]), @@ -695,13 +700,6 @@ def get_installation_id(self): self._storage[_INSTALLATION_ID] = installation_id return installation_id - def _get_token_from_cloud_shell(self, resource): # pylint: disable=no-self-use - from azure.cli.core.auth.adal_authentication import MSIAuthenticationWrapper - auth = MSIAuthenticationWrapper(resource=resource) - auth.set_token() - token_entry = auth.token - return (token_entry['token_type'], token_entry['access_token'], token_entry) - class MsiAccountTypes: # pylint: disable=no-method-argument,no-self-argument diff --git a/src/azure-cli-core/azure/cli/core/auth/adal_authentication.py b/src/azure-cli-core/azure/cli/core/auth/adal_authentication.py index e8ac5233f58..d53ac588a5b 100644 --- a/src/azure-cli-core/azure/cli/core/auth/adal_authentication.py +++ b/src/azure-cli-core/azure/cli/core/auth/adal_authentication.py @@ -4,11 +4,10 @@ # -------------------------------------------------------------------------------------------- import requests -from azure.core.credentials import AccessToken from knack.log import get_logger from msrestazure.azure_active_directory import MSIAuthentication -from .util import _normalize_scopes, scopes_to_resource +from .util import _normalize_scopes, scopes_to_resource, AccessToken logger = get_logger(__name__) diff --git a/src/azure-cli-core/azure/cli/core/auth/msal_authentication.py b/src/azure-cli-core/azure/cli/core/auth/msal_authentication.py index 3c0d46e043e..fe13ee3e4d4 100644 --- a/src/azure-cli-core/azure/cli/core/auth/msal_authentication.py +++ b/src/azure-cli-core/azure/cli/core/auth/msal_authentication.py @@ -14,7 +14,7 @@ from knack.util import CLIError from msal import PublicClientApplication, ConfidentialClientApplication -from .util import check_result +from .util import check_result, AccessToken # OAuth 2.0 client credentials flow parameter # https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow @@ -137,9 +137,14 @@ def _build_sdk_access_token(token_entry): import time request_time = int(time.time()) + # MSAL token entry sample: + # { + # 'access_token': 'eyJ0eXAiOiJKV...', + # 'token_type': 'Bearer', + # 'expires_in': 1618 + # } + # Importing azure.core.credentials.AccessToken is expensive. # This can slow down commands that doesn't need azure.core, like `az account get-access-token`. # So We define our own AccessToken. - from collections import namedtuple - AccessToken = namedtuple("AccessToken", ["token", "expires_on"]) return AccessToken(token_entry["access_token"], request_time + token_entry["expires_in"]) diff --git a/src/azure-cli-core/azure/cli/core/auth/tests/test_util.py b/src/azure-cli-core/azure/cli/core/auth/tests/test_util.py index f5db382d736..c96e5a446ed 100644 --- a/src/azure-cli-core/azure/cli/core/auth/tests/test_util.py +++ b/src/azure-cli-core/azure/cli/core/auth/tests/test_util.py @@ -6,7 +6,7 @@ # pylint: disable=protected-access import unittest -from ..util import scopes_to_resource, resource_to_scopes, _normalize_scopes, _generate_login_command +from azure.cli.core.auth.util import scopes_to_resource, resource_to_scopes, _normalize_scopes, _generate_login_command class TestUtil(unittest.TestCase): diff --git a/src/azure-cli-core/azure/cli/core/auth/util.py b/src/azure-cli-core/azure/cli/core/auth/util.py index ff6af520938..0186cdc3ad2 100644 --- a/src/azure-cli-core/azure/cli/core/auth/util.py +++ b/src/azure-cli-core/azure/cli/core/auth/util.py @@ -4,11 +4,16 @@ # -------------------------------------------------------------------------------------------- import os +from collections import namedtuple + from knack.log import get_logger logger = get_logger(__name__) +AccessToken = namedtuple("AccessToken", ["token", "expires_on"]) + + def aad_error_handler(error, **kwargs): """ Handle the error from AAD server returned by ADAL or MSAL. """ diff --git a/src/azure-cli-core/azure/cli/core/tests/test_profile.py b/src/azure-cli-core/azure/cli/core/tests/test_profile.py index da8431d4fb9..df43c5f66da 100644 --- a/src/azure-cli-core/azure/cli/core/tests/test_profile.py +++ b/src/azure-cli-core/azure/cli/core/tests/test_profile.py @@ -5,45 +5,39 @@ # pylint: disable=protected-access import json -import os -import sys -import unittest -from unittest import mock -import re import datetime - +import unittest from copy import deepcopy - -from azure.core.credentials import AccessToken +from unittest import mock from azure.cli.core._profile import (Profile, SubscriptionFinder, _attach_token_tenant, _transform_subscription_for_multiapi) - -from azure.mgmt.resource.subscriptions.models import \ - (Subscription, SubscriptionPolicies, SpendingLimit, ManagedByTenant, TenantIdDescription) - +from azure.cli.core.auth.util import AccessToken from azure.cli.core.mock import DummyCli from azure.identity import AuthenticationRecord +from azure.mgmt.resource.subscriptions.models import \ + (Subscription, SubscriptionPolicies, SpendingLimit, ManagedByTenant) from knack.util import CLIError - MOCK_ACCESS_TOKEN = "mock_access_token" -MOCK_EXPIRES_ON = 1630920323 +MOCK_EXPIRES_ON_STR = "1630920323" +MOCK_EXPIRES_ON_INT = 1630920323 +MOCK_EXPIRES_ON_DATETIME = datetime.datetime.fromtimestamp(MOCK_EXPIRES_ON_INT).strftime("%Y-%m-%d %H:%M:%S.%f") BEARER = 'Bearer' class CredentialMock: def __init__(self, *args, **kwargs): + # If get_token_scopes is checked, make sure to create a new instance of CredentialMock + # to avoid interference from other tests. + self.get_token_scopes = None super().__init__() def get_token(self, *scopes, **kwargs): - from azure.core.credentials import AccessToken - import time - now = int(time.time()) - # Mock sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py:230 - return AccessToken(MOCK_ACCESS_TOKEN, MOCK_EXPIRES_ON) + self.get_token_scopes = scopes + return AccessToken(MOCK_ACCESS_TOKEN, MOCK_EXPIRES_ON_INT) # Used as the return_value of azure.cli.core.auth.identity.Identity.get_user_credential @@ -53,16 +47,20 @@ def get_token(self, *scopes, **kwargs): class MSRestAzureAuthStub: + def __init__(self, *args, **kwargs): self._token = { 'token_type': 'Bearer', - 'access_token': TestProfile.test_msi_access_token + 'access_token': TestProfile.test_msi_access_token, + 'expires_on': MOCK_EXPIRES_ON_STR } self.set_token_invoked_count = 0 self.token_read_count = 0 + self.get_token_scopes = None self.client_id = kwargs.get('client_id') self.object_id = kwargs.get('object_id') self.msi_res_id = kwargs.get('msi_res_id') + self.resource = kwargs.get('resource') def set_token(self): self.set_token_invoked_count += 1 @@ -76,6 +74,10 @@ def token(self): def token(self, value): self._token = value + def get_token(self, *args, **kwargs): + self.get_token_scopes = args + return AccessToken(self.token['access_token'], int(self.token['expires_on'])) + class TestProfile(unittest.TestCase): @@ -269,7 +271,8 @@ def setUpClass(cls): 'authority_type': 'MSSTS' }] - cls.msal_scopes = ['https://foo/.default'] + cls.adal_resource = 'https://foo/' + cls.msal_scopes = ['https://foo//.default'] cls.service_principal_id = "00000001-0000-0000-0000-000000000000" cls.service_principal_secret = "test_secret" @@ -1033,8 +1036,10 @@ def test_get_login_credentials_msi_user_assigned_with_res_id(self): self.assertTrue(cred.token_read_count) self.assertTrue(cred.msi_res_id, test_res_id) - @mock.patch('azure.cli.core.auth.identity.Identity.get_user_credential', return_value=credential_mock) + @mock.patch('azure.cli.core.auth.identity.Identity.get_user_credential') def test_get_raw_token(self, get_user_credential_mock): + credential_mock_temp = CredentialMock() + get_user_credential_mock.return_value = credential_mock_temp cli = DummyCli() # setup storage_mock = {'subscriptions': None} @@ -1046,7 +1051,7 @@ def test_get_raw_token(self, get_user_credential_mock): # action # Get token with ADAL-style resource - resource_result = profile.get_raw_token(resource='https://foo') + resource_result = profile.get_raw_token(resource=self.adal_resource) # Get token with MSAL-style scopes scopes_result = profile.get_raw_token(scopes=self.msal_scopes) @@ -1056,18 +1061,23 @@ def test_get_raw_token(self, get_user_credential_mock): self.assertEqual(creds[0], 'Bearer') self.assertEqual(creds[1], MOCK_ACCESS_TOKEN) - self.assertEqual(creds[2]['expires_on'], MOCK_EXPIRES_ON) + self.assertEqual(creds[2]['expires_on'], MOCK_EXPIRES_ON_INT) + self.assertEqual(creds[2]['expiresOn'], MOCK_EXPIRES_ON_DATETIME) # subscription should be set self.assertEqual(sub, self.subscription1.subscription_id) self.assertEqual(tenant, self.tenant_id) # Test get_raw_token with tenant - creds, sub, tenant = profile.get_raw_token(resource='https://foo', tenant=self.tenant_id) + creds, sub, tenant = profile.get_raw_token(resource=self.adal_resource, tenant=self.tenant_id) + + # verify + assert list(credential_mock_temp.get_token_scopes) == self.msal_scopes self.assertEqual(creds[0], 'Bearer') self.assertEqual(creds[1], MOCK_ACCESS_TOKEN) - self.assertEqual(creds[2]['expires_on'], MOCK_EXPIRES_ON) + self.assertEqual(creds[2]['expires_on'], MOCK_EXPIRES_ON_INT) + self.assertEqual(creds[2]['expiresOn'], MOCK_EXPIRES_ON_DATETIME) # subscription shouldn't be set self.assertIsNone(sub) @@ -1075,7 +1085,8 @@ def test_get_raw_token(self, get_user_credential_mock): @mock.patch('azure.cli.core.auth.identity.Identity.get_service_principal_credential') def test_get_raw_token_for_sp(self, get_service_principal_credential_mock): - get_service_principal_credential_mock.return_value = CredentialMock() + credential_mock_temp = CredentialMock() + get_service_principal_credential_mock.return_value = credential_mock_temp cli = DummyCli() # setup storage_mock = {'subscriptions': None} @@ -1085,24 +1096,28 @@ def test_get_raw_token_for_sp(self, get_service_principal_credential_mock): True) profile._set_subscriptions(consolidated) # action - creds, sub, tenant = profile.get_raw_token(resource='https://foo') + creds, sub, tenant = profile.get_raw_token(resource=self.adal_resource) # verify + assert list(credential_mock_temp.get_token_scopes) == self.msal_scopes + self.assertEqual(creds[0], BEARER) self.assertEqual(creds[1], MOCK_ACCESS_TOKEN) # the last in the tuple is the whole token entry which has several fields - self.assertEqual(creds[2]['expires_on'], MOCK_EXPIRES_ON) + self.assertEqual(creds[2]['expires_on'], MOCK_EXPIRES_ON_INT) + self.assertEqual(creds[2]['expiresOn'], MOCK_EXPIRES_ON_DATETIME) # subscription should be set self.assertEqual(sub, self.subscription1.subscription_id) self.assertEqual(tenant, self.tenant_id) # Test get_raw_token with tenant - creds, sub, tenant = profile.get_raw_token(resource='https://foo', tenant=self.tenant_id) + creds, sub, tenant = profile.get_raw_token(resource=self.adal_resource, tenant=self.tenant_id) self.assertEqual(creds[0], BEARER) self.assertEqual(creds[1], MOCK_ACCESS_TOKEN) - self.assertEqual(creds[2]['expires_on'], MOCK_EXPIRES_ON) + self.assertEqual(creds[2]['expires_on'], MOCK_EXPIRES_ON_INT) + self.assertEqual(creds[2]['expiresOn'], MOCK_EXPIRES_ON_DATETIME) # subscription shouldn't be set self.assertIsNone(sub) @@ -1122,20 +1137,34 @@ def test_get_raw_token_msi_system_assigned(self, mock_msi_auth): True) profile._set_subscriptions(consolidated) - mock_msi_auth.side_effect = MSRestAzureAuthStub + mi_auth_instance = None + + def mi_auth_factory(*args, **kwargs): + nonlocal mi_auth_instance + mi_auth_instance = MSRestAzureAuthStub(*args, **kwargs) + return mi_auth_instance + + mock_msi_auth.side_effect = mi_auth_factory # action - cred, subscription_id, tenant_id = profile.get_raw_token(resource='http://test_resource') + cred, subscription_id, tenant_id = profile.get_raw_token(resource=self.adal_resource) + + # Make sure resource/scopes are passed to MSIAuthenticationWrapper + assert mi_auth_instance.resource == self.adal_resource + assert list(mi_auth_instance.get_token_scopes) == self.msal_scopes - # assert self.assertEqual(subscription_id, test_subscription_id) self.assertEqual(cred[0], 'Bearer') self.assertEqual(cred[1], TestProfile.test_msi_access_token) + + # Make sure expires_on and expiresOn are set + self.assertEqual(cred[2]['expires_on'], MOCK_EXPIRES_ON_INT) + self.assertEqual(cred[2]['expiresOn'], MOCK_EXPIRES_ON_DATETIME) self.assertEqual(subscription_id, test_subscription_id) self.assertEqual(tenant_id, test_tenant_id) # verify tenant shouldn't be specified for MSI account - with self.assertRaisesRegexp(CLIError, "MSI"): + with self.assertRaisesRegexp(CLIError, "Tenant shouldn't be specified"): cred, subscription_id, _ = profile.get_raw_token(resource='http://test_resource', tenant=self.tenant_id) @mock.patch('azure.cli.core._profile.in_cloud_console', autospec=True) @@ -1155,15 +1184,29 @@ def test_get_raw_token_in_cloud_console(self, mock_msi_auth, mock_in_cloud_conso consolidated[0]['user']['cloudShellID'] = True profile._set_subscriptions(consolidated) - mock_msi_auth.side_effect = MSRestAzureAuthStub + mi_auth_instance = None + + def mi_auth_factory(*args, **kwargs): + nonlocal mi_auth_instance + mi_auth_instance = MSRestAzureAuthStub(*args, **kwargs) + return mi_auth_instance + + mock_msi_auth.side_effect = mi_auth_factory # action - cred, subscription_id, tenant_id = profile.get_raw_token(resource='http://test_resource') + cred, subscription_id, tenant_id = profile.get_raw_token(resource=self.adal_resource) + + # Make sure resource/scopes are passed to MSIAuthenticationWrapper + assert mi_auth_instance.resource == self.adal_resource + assert list(mi_auth_instance.get_token_scopes) == self.msal_scopes - # assert self.assertEqual(subscription_id, test_subscription_id) self.assertEqual(cred[0], 'Bearer') self.assertEqual(cred[1], TestProfile.test_msi_access_token) + + # Make sure expires_on and expiresOn are set + self.assertEqual(cred[2]['expires_on'], MOCK_EXPIRES_ON_INT) + self.assertEqual(cred[2]['expiresOn'], MOCK_EXPIRES_ON_DATETIME) self.assertEqual(subscription_id, test_subscription_id) self.assertEqual(tenant_id, test_tenant_id) diff --git a/src/azure-cli/azure/cli/command_modules/profile/custom.py b/src/azure-cli/azure/cli/command_modules/profile/custom.py index 8319689d6ef..6ed4e38e965 100644 --- a/src/azure-cli/azure/cli/command_modules/profile/custom.py +++ b/src/azure-cli/azure/cli/command_modules/profile/custom.py @@ -80,7 +80,7 @@ def get_access_token(cmd, subscription=None, resource=None, scopes=None, resourc 'tokenType': creds[0], 'accessToken': creds[1], # 'expires_on': creds[2].get('expires_on', None), - 'expiresOn': creds[2].get('expiresOn', None), + 'expiresOn': creds[2]['expiresOn'], 'tenant': tenant } if subscription: