Skip to content

Commit

Permalink
[Core] Add workaround for cross-tenant authentication with Track 2 SD…
Browse files Browse the repository at this point in the history
…Ks (#16797)

Co-authored-by: Feiyue Yu <[email protected]>
  • Loading branch information
jiasli and qwordy authored Apr 9, 2021
1 parent a3604d1 commit 039fa21
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 29 deletions.
5 changes: 5 additions & 0 deletions src/azure-cli-core/azure/cli/core/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,7 @@ def get_login_credentials(self, resource=None, subscription_id=None, aux_subscri

identity_type, identity_id = Profile._try_parse_msi_account_name(account)

# Make sure external_tenants_info only contains real external tenant (no current tenant).
external_tenants_info = []
if aux_tenants:
external_tenants_info = [tenant for tenant in aux_tenants if tenant != account[_TENANT_ID]]
Expand All @@ -560,6 +561,10 @@ def get_login_credentials(self, resource=None, subscription_id=None, aux_subscri
if sub[_TENANT_ID] != account[_TENANT_ID]:
external_tenants_info.append(sub[_TENANT_ID])

if external_tenants_info and (identity_type or in_cloud_console()):
raise CLIError("Cross-tenant authentication is not supported by managed identity and Cloud Shell. "
"Please run `az login` with a user account or a service principal.")

if identity_type is None:
def _retrieve_token(sdk_resource=None):
# When called by
Expand Down
21 changes: 12 additions & 9 deletions src/azure-cli-core/azure/cli/core/adal_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _get_token(self, sdk_resource=None):
"""
external_tenant_tokens = None
try:
scheme, token, full_token = self._token_retriever(sdk_resource)
scheme, token, token_entry = self._token_retriever(sdk_resource)
if self._external_tenant_token_retriever:
external_tenant_tokens = self._external_tenant_token_retriever(sdk_resource)
except CLIError as err:
Expand All @@ -52,17 +52,20 @@ def _get_token(self, sdk_resource=None):
except requests.exceptions.ConnectionError as err:
raise CLIError('Please ensure you have network connection. Error detail: ' + str(err))

return scheme, token, full_token, external_tenant_tokens
# scheme: str. The token scheme. Should always be 'Bearer'.
# token: str. The raw access token.
# token_entry: dict. The full token entry.
# external_tenant_tokens: [(scheme: str, token: str, token_entry: dict), ...]
return scheme, token, token_entry, external_tenant_tokens

def get_all_tokens(self, *scopes):
scheme, token, full_token, external_tenant_tokens = self._get_token(_try_scopes_to_resource(scopes))
return scheme, token, full_token, external_tenant_tokens
return self._get_token(_try_scopes_to_resource(scopes))

# This method is exposed for Azure Core.
def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
logger.debug("AdalAuthentication.get_token invoked by Track 2 SDK with scopes=%s", scopes)

_, token, full_token, _ = self._get_token(_try_scopes_to_resource(scopes))
_, token, token_entry, _ = self._get_token(_try_scopes_to_resource(scopes))

# NEVER use expiresIn (expires_in) as the token is cached and expiresIn will be already out-of date
# when being retrieved.
Expand Down Expand Up @@ -92,10 +95,10 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# "_clientId": "22800c35-46c2-4210-b8a7-d8c3ec3b526f",
# "_authority": "https://login.microsoftonline.com/54826b22-38d6-4fb2-bad9-b7b93a3e9c5a"
# }
if 'expiresOn' in full_token:
if 'expiresOn' in token_entry:
import datetime
expires_on_timestamp = int(_timestamp(
datetime.datetime.strptime(full_token['expiresOn'], '%Y-%m-%d %H:%M:%S.%f')))
datetime.datetime.strptime(token_entry['expiresOn'], '%Y-%m-%d %H:%M:%S.%f')))
return AccessToken(token, expires_on_timestamp)

# Cloud Shell (Managed Identity) token entry sample:
Expand All @@ -108,8 +111,8 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# "resource": "https://management.core.windows.net/",
# "token_type": "Bearer"
# }
if 'expires_on' in full_token:
return AccessToken(token, int(full_token['expires_on']))
if 'expires_on' in token_entry:
return AccessToken(token, int(token_entry['expires_on']))

from azure.cli.core.azclierror import CLIInternalError
raise CLIInternalError("No expiresOn or expires_on is available in the token entry.")
Expand Down
10 changes: 10 additions & 0 deletions src/azure-cli-core/azure/cli/core/commands/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,16 @@ def _get_mgmt_service_client(cli_ctx,
client_kwargs.update(_prepare_client_kwargs_track2(cli_ctx))
client_kwargs['credential_scopes'] = resource_to_scopes(resource)

# Track 2 currently lacks the ability to take external credentials.
# https://github.com/Azure/azure-sdk-for-python/issues/8313
# As a temporary workaround, manually add external tokens to 'x-ms-authorization-auxiliary' header.
# https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/authenticate-multi-tenant
if getattr(cred, "_external_tenant_token_retriever", None):
*_, external_tenant_tokens = cred.get_all_tokens(*resource_to_scopes(resource))
# Hard-code scheme to 'Bearer' as _BearerTokenCredentialPolicyBase._update_headers does.
client_kwargs['headers']['x-ms-authorization-auxiliary'] = \
', '.join("Bearer {}".format(t[1]) for t in external_tenant_tokens)

if subscription_bound:
client = client_type(cred, subscription_id, **client_kwargs)
else:
Expand Down
21 changes: 1 addition & 20 deletions src/azure-cli/azure/cli/command_modules/vm/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -3387,31 +3387,13 @@ def create_image_version(cmd, resource_group_name, gallery_name, gallery_image_n
# print(target_regions)
from msrestazure.tools import resource_id, is_valid_resource_id
from azure.cli.core.commands.client_factory import get_subscription_id
from azure.cli.core._profile import Profile

ImageVersionPublishingProfile, GalleryArtifactSource, ManagedArtifact, ImageVersion, TargetRegion = cmd.get_models(
'GalleryImageVersionPublishingProfile', 'GalleryArtifactSource', 'ManagedArtifact', 'GalleryImageVersion',
'TargetRegion')
aux_subscriptions = _get_image_version_aux_subscription(managed_image, os_snapshot, data_snapshots)
client = _compute_client_factory(cmd.cli_ctx, aux_subscriptions=aux_subscriptions)

# Auxiliary tokens, pass it to init or operation
external_bearer_token = None
if aux_subscriptions:
profile = Profile(cli_ctx=cmd.cli_ctx)
resource = cmd.cli_ctx.cloud.endpoints.active_directory_resource_id
cred, _, _ = profile.get_login_credentials(resource=resource,
aux_subscriptions=aux_subscriptions)
_, _, _, external_tokens = cred.get_all_tokens('https://management.azure.com/.default')
if external_tokens:
external_token = external_tokens[0]
if len(external_token) >= 2:
external_bearer_token = external_token[0] + ' ' + external_token[1]
else:
logger.warning('Getting external tokens failed.')
else:
logger.warning('Getting external tokens failed.')

location = location or _get_resource_group_location(cmd.cli_ctx, resource_group_name)
end_of_life_date = fix_gallery_image_date_info(end_of_life_date)
if managed_image and not is_valid_resource_id(managed_image):
Expand Down Expand Up @@ -3468,8 +3450,7 @@ def create_image_version(cmd, resource_group_name, gallery_name, gallery_image_n
gallery_name=gallery_name,
gallery_image_name=gallery_image_name,
gallery_image_version_name=gallery_image_version,
gallery_image_version=image_version,
headers={'x-ms-authorization-auxiliary': external_bearer_token}
gallery_image_version=image_version
)


Expand Down

0 comments on commit 039fa21

Please sign in to comment.