Skip to content

Commit

Permalink
Identity not use is chained (#31328)
Browse files Browse the repository at this point in the history
* not use is_chained

* update
  • Loading branch information
xiangyan99 authored Jul 26, 2023
1 parent 674698c commit 9f1bf3a
Show file tree
Hide file tree
Showing 16 changed files with 81 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from azure.core.exceptions import ClientAuthenticationError

from .. import CredentialUnavailableError
from .._internal import resolve_tenant
from .._internal import resolve_tenant, within_dac
from .._internal.decorators import log_get_token

CLI_NOT_FOUND = (
Expand Down Expand Up @@ -75,11 +75,9 @@ def __init__(
tenant_id: str = "",
additionally_allowed_tenants: Optional[List[str]] = None,
process_timeout: int = 10,
_is_chained: bool = False,
) -> None:

self.tenant_id = tenant_id
self._is_chained = _is_chained
self._additionally_allowed_tenants = additionally_allowed_tenants or []
self._process_timeout = process_timeout

Expand Down Expand Up @@ -123,7 +121,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
)
if tenant:
command += " --tenant-id " + tenant
output = _run_command(command, self._process_timeout, _is_chained=self._is_chained)
output = _run_command(command, self._process_timeout)

token = parse_token(output)
if not token:
Expand All @@ -133,7 +131,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
f"To mitigate this issue, please refer to the troubleshooting guidelines here at "
f"https://aka.ms/azsdk/python/identity/azdevclicredential/troubleshoot."
)
if self._is_chained:
if within_dac.get():
raise CredentialUnavailableError(message=message)
raise ClientAuthenticationError(message=message)

Expand Down Expand Up @@ -189,7 +187,7 @@ def sanitize_output(output: str) -> str:
return re.sub(r"\"token\": \"(.*?)(\"|$)", "****", output)


def _run_command(command: str, timeout: int, _is_chained: bool = False) -> str:
def _run_command(command: str, timeout: int) -> str:
# Ensure executable exists in PATH first. This avoids a subprocess call that would fail anyway.
if shutil.which(EXECUTABLE_NAME) is None:
raise CredentialUnavailableError(message=CLI_NOT_FOUND)
Expand Down Expand Up @@ -223,7 +221,7 @@ def _run_command(command: str, timeout: int, _is_chained: bool = False) -> str:
message = sanitize_output(ex.stderr)
else:
message = "Failed to invoke Azure Developer CLI"
if _is_chained:
if within_dac.get():
raise CredentialUnavailableError(message=message) from ex
raise ClientAuthenticationError(message=message) from ex
except OSError as ex:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from azure.core.exceptions import ClientAuthenticationError

from .. import CredentialUnavailableError
from .._internal import _scopes_to_resource, resolve_tenant
from .._internal import _scopes_to_resource, resolve_tenant, within_dac
from .._internal.decorators import log_get_token


Expand Down Expand Up @@ -53,11 +53,9 @@ def __init__(
tenant_id: str = "",
additionally_allowed_tenants: Optional[List[str]] = None,
process_timeout: int = 10,
_is_chained: bool = False,
) -> None:

self.tenant_id = tenant_id
self._is_chained = _is_chained
self._additionally_allowed_tenants = additionally_allowed_tenants or []
self._process_timeout = process_timeout

Expand Down Expand Up @@ -97,7 +95,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
)
if tenant:
command += " --tenant " + tenant
output = _run_command(command, self._process_timeout, _is_chained=self._is_chained)
output = _run_command(command, self._process_timeout)

token = parse_token(output)
if not token:
Expand All @@ -107,7 +105,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
f"To mitigate this issue, please refer to the troubleshooting guidelines here at "
f"https://aka.ms/azsdk/python/identity/azclicredential/troubleshoot."
)
if self._is_chained:
if within_dac.get():
raise CredentialUnavailableError(message=message)
raise ClientAuthenticationError(message=message)

Expand Down Expand Up @@ -165,7 +163,7 @@ def sanitize_output(output: str) -> str:
return re.sub(r"\"accessToken\": \"(.*?)(\"|$)", "****", output)


def _run_command(command: str, timeout: int, _is_chained: bool = False) -> str:
def _run_command(command: str, timeout: int) -> str:
# Ensure executable exists in PATH first. This avoids a subprocess call that would fail anyway.
if shutil.which(EXECUTABLE_NAME) is None:
raise CredentialUnavailableError(message=CLI_NOT_FOUND)
Expand Down Expand Up @@ -198,7 +196,7 @@ def _run_command(command: str, timeout: int, _is_chained: bool = False) -> str:
message = sanitize_output(ex.stderr)
else:
message = "Failed to invoke Azure CLI"
if _is_chained:
if within_dac.get():
raise CredentialUnavailableError(message=message) from ex
raise ClientAuthenticationError(message=message) from ex
except OSError as ex:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from .azure_cli import get_safe_working_dir
from .. import CredentialUnavailableError
from .._internal import _scopes_to_resource, resolve_tenant
from .._internal import _scopes_to_resource, resolve_tenant, within_dac
from .._internal.decorators import log_get_token


Expand Down Expand Up @@ -67,11 +67,9 @@ def __init__(
tenant_id: str = "",
additionally_allowed_tenants: Optional[List[str]] = None,
process_timeout: int = 10,
_is_chained: bool = False
) -> None:

self.tenant_id = tenant_id
self._is_chained = _is_chained
self._additionally_allowed_tenants = additionally_allowed_tenants or []
self._process_timeout = process_timeout

Expand Down Expand Up @@ -109,7 +107,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
)
command_line = get_command_line(scopes, tenant_id)
output = run_command_line(command_line, self._process_timeout)
token = parse_token(output, _is_chained=self._is_chained)
token = parse_token(output)
return token


Expand Down Expand Up @@ -155,13 +153,13 @@ def start_process(args: List[str]) -> "subprocess.Popen":
return proc


def parse_token(output: str, _is_chained: bool = False) -> AccessToken:
def parse_token(output: str) -> AccessToken:
for line in output.split():
if line.startswith("azsdk%"):
_, token, expires_on = line.split("%")
return AccessToken(token, int(expires_on))

if _is_chained:
if within_dac.get():
raise CredentialUnavailableError(message='Unexpected output from Get-AzAccessToken: "{}"'.format(output))
raise ClientAuthenticationError(message='Unexpected output from Get-AzAccessToken: "{}"'.format(output))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from .. import CredentialUnavailableError
from .._constants import DEVELOPER_SIGN_ON_CLIENT_ID
from .._internal import AuthCodeRedirectServer, InteractiveCredential, wrap_exceptions
from .._internal import AuthCodeRedirectServer, InteractiveCredential, wrap_exceptions, within_dac


class InteractiveBrowserCredential(InteractiveCredential):
Expand Down Expand Up @@ -78,7 +78,6 @@ def __init__(self, **kwargs: Any) -> None:
else:
self._parsed_url = None

self._is_chained = kwargs.pop("_is_chained", False)
self._login_hint = kwargs.pop("login_hint", None)
self._timeout = kwargs.pop("timeout", 300)
self._server_class = kwargs.pop("_server_class", AuthCodeRedirectServer)
Expand Down Expand Up @@ -148,11 +147,11 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> Dict:
except socket.error as ex:
raise CredentialUnavailableError(message="Couldn't start an HTTP server.") from ex
if "access_token" not in result and "error_description" in result:
if self._is_chained:
if within_dac.get():
raise CredentialUnavailableError(message=result["error_description"])
raise ClientAuthenticationError(message=result.get("error_description"))
if "access_token" not in result:
if self._is_chained:
if within_dac.get():
raise CredentialUnavailableError(message="Failed to authenticate user")
raise ClientAuthenticationError(message="Failed to authenticate user")

Expand Down
31 changes: 12 additions & 19 deletions sdk/identity/azure-identity/azure/identity/_credentials/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from azure.core.credentials import AccessToken
from .._constants import EnvironmentVariables
from .._internal import get_default_authority, normalize_authority
from .._internal import get_default_authority, normalize_authority, within_dac
from .azure_powershell import AzurePowerShellCredential
from .browser import InteractiveBrowserCredential
from .chained import ChainedTokenCredential
Expand Down Expand Up @@ -170,37 +170,28 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
try:
# username and/or tenant_id are only required when the cache contains tokens for multiple identities
shared_cache = SharedTokenCacheCredential(
username=shared_cache_username,
tenant_id=shared_cache_tenant_id,
authority=authority,
_is_chained=True,
**kwargs
username=shared_cache_username, tenant_id=shared_cache_tenant_id, authority=authority, **kwargs
)
credentials.append(shared_cache)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.info("Shared token cache is unavailable: '%s'", ex)
if not exclude_visual_studio_code_credential:
credentials.append(VisualStudioCodeCredential(_is_chained=True, **vscode_args))
credentials.append(VisualStudioCodeCredential(**vscode_args))
if not exclude_cli_credential:
credentials.append(AzureCliCredential(process_timeout=process_timeout, _is_chained=True))
credentials.append(AzureCliCredential(process_timeout=process_timeout))
if not exclude_powershell_credential:
credentials.append(AzurePowerShellCredential(process_timeout=process_timeout, _is_chained=True))
credentials.append(AzurePowerShellCredential(process_timeout=process_timeout))
if not exclude_developer_cli_credential:
credentials.append(AzureDeveloperCliCredential(process_timeout=process_timeout, _is_chained=True))
credentials.append(AzureDeveloperCliCredential(process_timeout=process_timeout))
if not exclude_interactive_browser_credential:
if interactive_browser_client_id:
credentials.append(
InteractiveBrowserCredential(
tenant_id=interactive_browser_tenant_id,
client_id=interactive_browser_client_id,
_is_chained=True,
**kwargs
tenant_id=interactive_browser_tenant_id, client_id=interactive_browser_client_id, **kwargs
)
)
else:
credentials.append(
InteractiveBrowserCredential(tenant_id=interactive_browser_tenant_id, _is_chained=True, **kwargs)
)
credentials.append(InteractiveBrowserCredential(tenant_id=interactive_browser_tenant_id, **kwargs))

super(DefaultAzureCredential, self).__init__(*credentials)

Expand All @@ -226,5 +217,7 @@ def get_token(self, *scopes: str, **kwargs) -> AccessToken:
"%s acquired a token from %s", self.__class__.__name__, self._successful_credential.__class__.__name__
)
return token

return super(DefaultAzureCredential, self).get_token(*scopes, **kwargs)
within_dac.set(True)
token = super(DefaultAzureCredential, self).get_token(*scopes, **kwargs)
within_dac.set(False)
return token
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from azure.core.exceptions import ClientAuthenticationError

from .. import CredentialUnavailableError
from .._internal import resolve_tenant, validate_tenant_id
from .._internal import resolve_tenant, validate_tenant_id, within_dac
from .._internal.decorators import wrap_exceptions
from .._internal.msal_client import MsalClient
from .._internal.shared_token_cache import NO_TOKEN
Expand All @@ -38,7 +38,6 @@ def __init__(
# authenticate in the tenant that produced the record unless "tenant_id" specifies another
self._tenant_id = tenant_id or self._auth_record.tenant_id
validate_tenant_id(self._tenant_id)
self._is_chained = kwargs.pop("_is_chained", False)
self._cache = kwargs.pop("_cache", None)
self._cache_persistence_options = kwargs.pop("cache_persistence_options", None)
self._client_applications: Dict[str, PublicClientApplication] = {}
Expand All @@ -61,7 +60,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
self._initialize()

if not self._cache:
if self._is_chained:
if within_dac.get():
raise CredentialUnavailableError(message="Shared token cache unavailable")
raise ClientAuthenticationError(message="Shared token cache unavailable")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from azure.core.exceptions import ClientAuthenticationError
from .._exceptions import CredentialUnavailableError
from .._constants import AzureAuthorityHosts, AZURE_VSCODE_CLIENT_ID, EnvironmentVariables
from .._internal import normalize_authority, validate_tenant_id
from .._internal import normalize_authority, validate_tenant_id, within_dac
from .._internal.aad_client import AadClient, AadClientBase
from .._internal.get_token_mixin import GetTokenMixin
from .._internal.decorators import log_get_token
Expand All @@ -29,7 +29,6 @@ def __init__(self, **kwargs: Any) -> None:
super(_VSCodeCredentialBase, self).__init__()

user_settings = get_user_settings()
self._is_chained = kwargs.pop("_is_chained", False)
self._cloud = user_settings.get("azure.cloud", "AzureCloud")
self._refresh_token = None
self._unavailable_reason = ""
Expand Down Expand Up @@ -162,7 +161,7 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
" to troubleshoot this issue."
)
raise CredentialUnavailableError(message=error_message)
if self._is_chained:
if within_dac.get():
try:
token = super(VisualStudioCodeCredential, self).get_token(*scopes, **kwargs)
return token
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
resolve_tenant,
validate_tenant_id,
within_credential_chain,
within_dac,
)


Expand Down Expand Up @@ -47,6 +48,7 @@ def _scopes_to_resource(*scopes) -> str:
"normalize_authority",
"resolve_tenant",
"within_credential_chain",
"within_dac",
"wrap_exceptions",
"validate_tenant_id",
]
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def __init__(
self._authority = normalize_authority(authority) if authority else get_default_authority()
environment = urlparse(self._authority).netloc
self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,))
self._is_chained = kwargs.pop("_is_chained", False)
self._username = username
self._tenant_id = tenant_id
self._cache = kwargs.pop("_cache", None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .._constants import EnvironmentVariables, KnownAuthorities

within_credential_chain = ContextVar("within_credential_chain", default=False)
within_dac = ContextVar("within_dac", default=False)

_LOGGER = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
parse_token,
sanitize_output,
)
from ..._internal import resolve_tenant
from ..._internal import resolve_tenant, within_dac


class AzureDeveloperCliCredential(AsyncContextManager):
Expand Down Expand Up @@ -72,11 +72,9 @@ def __init__(
tenant_id: str = "",
additionally_allowed_tenants: Optional[List[str]] = None,
process_timeout: int = 10,
_is_chained: bool = False,
) -> None:

self.tenant_id = tenant_id
self._is_chained = _is_chained
self._additionally_allowed_tenants = additionally_allowed_tenants or []
self._process_timeout = process_timeout

Expand Down Expand Up @@ -113,7 +111,7 @@ async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:

if tenant:
command += " --tenant-id " + tenant
output = await _run_command(command, self._process_timeout, _is_chained=self._is_chained)
output = await _run_command(command, self._process_timeout)

token = parse_token(output)
if not token:
Expand All @@ -123,7 +121,7 @@ async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
f"To mitigate this issue, please refer to the troubleshooting guidelines here at "
f"https://aka.ms/azsdk/python/identity/azdevclicredential/troubleshoot."
)
if self._is_chained:
if within_dac.get():
raise CredentialUnavailableError(message=message)
raise ClientAuthenticationError(message=message)

Expand All @@ -133,7 +131,7 @@ async def close(self) -> None:
"""Calling this method is unnecessary"""


async def _run_command(command: str, timeout: int, _is_chained: bool = False) -> str:
async def _run_command(command: str, timeout: int) -> str:
# Ensure executable exists in PATH first. This avoids a subprocess call that would fail anyway.
if shutil.which(EXECUTABLE_NAME) is None:
raise CredentialUnavailableError(message=CLI_NOT_FOUND)
Expand Down Expand Up @@ -175,6 +173,6 @@ async def _run_command(command: str, timeout: int, _is_chained: bool = False) ->
raise CredentialUnavailableError(message=NOT_LOGGED_IN)

message = sanitize_output(stderr) if stderr else "Failed to invoke Azure Developer CLI"
if _is_chained:
if within_dac.get():
raise CredentialUnavailableError(message=message)
raise ClientAuthenticationError(message=message)
Loading

0 comments on commit 9f1bf3a

Please sign in to comment.