Skip to content

Commit

Permalink
Implementing Telemetry V4
Browse files Browse the repository at this point in the history
Implement Telemetry's app-wide state

Test cases for telemetry id on most public methods

Test telemetry buffer for offline states
  • Loading branch information
rayluo committed Mar 18, 2021
1 parent 9d29158 commit 31b24af
Show file tree
Hide file tree
Showing 4 changed files with 346 additions and 113 deletions.
166 changes: 86 additions & 80 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
import sys
import warnings
import uuid
from threading import Lock

import requests

Expand All @@ -18,6 +18,7 @@
from .wstrust_request import send_request as wst_send_request
from .wstrust_response import *
from .token_cache import TokenCache
import msal.telemetry


# The __init__.py will import this. Not the other way around.
Expand Down Expand Up @@ -52,18 +53,6 @@ def decorate_scope(
decorated = scope_set | reserved_scope
return list(decorated)

CLIENT_REQUEST_ID = 'client-request-id'
CLIENT_CURRENT_TELEMETRY = 'x-client-current-telemetry'

def _get_new_correlation_id():
correlation_id = str(uuid.uuid4())
logger.debug("Generates correlation_id: %s", correlation_id)
return correlation_id


def _build_current_telemetry_request_header(public_api_id, force_refresh=False):
return "1|{},{}|".format(public_api_id, "1" if force_refresh else "0")


def extract_certs(public_cert_content):
# Parses raw public certificate file contents and returns a list of strings
Expand Down Expand Up @@ -257,6 +246,14 @@ def __init__(
self.token_cache = token_cache or TokenCache()
self.client = self._build_client(client_credential, self.authority)
self.authority_groups = None
self._telemetry_buffer = {}
self._telemetry_lock = Lock()

def _build_telemetry_context(
self, api_id, correlation_id=None, refresh_reason=None):
return msal.telemetry._TelemetryContext(
self._telemetry_buffer, self._telemetry_lock, api_id,
correlation_id=correlation_id, refresh_reason=refresh_reason)

def _build_client(self, client_credential, authority):
client_assertion = None
Expand Down Expand Up @@ -520,21 +517,21 @@ def authorize(): # A controller in a web app
return redirect(url_for("index"))
"""
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
return _clean_up(self.client.obtain_token_by_auth_code_flow(
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID)
response =_clean_up(self.client.obtain_token_by_auth_code_flow(
auth_code_flow,
auth_response,
scope=decorate_scope(scopes, self.client_id) if scopes else None,
headers={
CLIENT_REQUEST_ID: _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID),
},
headers=telemetry_context.generate_headers(),
data=dict(
kwargs.pop("data", {}),
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities,
auth_code_flow.pop("claims_challenge", None))),
**kwargs))
telemetry_context.update_telemetry(response)
return response

def acquire_token_by_authorization_code(
self,
Expand Down Expand Up @@ -593,20 +590,20 @@ def acquire_token_by_authorization_code(
"Change your acquire_token_by_authorization_code() "
"to acquire_token_by_auth_code_flow()", DeprecationWarning)
with warnings.catch_warnings(record=True):
return _clean_up(self.client.obtain_token_by_authorization_code(
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID)
response = _clean_up(self.client.obtain_token_by_authorization_code(
code, redirect_uri=redirect_uri,
scope=decorate_scope(scopes, self.client_id),
headers={
CLIENT_REQUEST_ID: _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID),
},
headers=telemetry_context.generate_headers(),
data=dict(
kwargs.pop("data", {}),
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)),
nonce=nonce,
**kwargs))
telemetry_context.update_telemetry(resposne)
return response

def get_accounts(self, username=None):
"""Get a list of accounts which previously signed in, i.e. exists in cache.
Expand Down Expand Up @@ -735,7 +732,7 @@ def acquire_token_silent(
- None when cache lookup does not yield a token.
"""
result = self.acquire_token_silent_with_error(
scopes, account, authority, force_refresh,
scopes, account, authority=authority, force_refresh=force_refresh,
claims_challenge=claims_challenge, **kwargs)
return result if result and "error" not in result else None

Expand Down Expand Up @@ -780,7 +777,7 @@ def acquire_token_silent_with_error(
"""
assert isinstance(scopes, list), "Invalid parameter type"
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
correlation_id = _get_new_correlation_id()
correlation_id = msal.telemetry._get_new_correlation_id()
if authority:
warnings.warn("We haven't decided how/if this method will accept authority parameter")
# the_authority = Authority(
Expand Down Expand Up @@ -851,9 +848,11 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
target=scopes,
query=query)
now = time.time()
refresh_reason = msal.telemetry.AT_ABSENT
for entry in matches:
expires_in = int(entry["expires_on"]) - now
if expires_in < 5*60: # Then consider it expired
refresh_reason = msal.telemetry.AT_EXPIRED
continue # Removal is not necessary, it will be overwritten
logger.debug("Cache hit an AT")
access_token_from_cache = { # Mimic a real response
Expand All @@ -862,13 +861,18 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
"expires_in": int(expires_in), # OAuth2 specs defines it as int
}
if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging
refresh_reason = msal.telemetry.AT_AGING
break # With a fallback in hand, we break here to go refresh
self._build_telemetry_context(-1).hit_an_access_token()
return access_token_from_cache # It is still good as new
else:
refresh_reason = msal.telemetry.FORCE_REFRESH # TODO: It could also mean claims_challenge
assert refresh_reason, "It should have been established at this point"
try:
result = self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
result = _clean_up(self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
authority, decorate_scope(scopes, self.client_id), account,
force_refresh=force_refresh, claims_challenge=claims_challenge, **kwargs)
result = _clean_up(result)
refresh_reason=refresh_reason, claims_challenge=claims_challenge,
**kwargs))
if (result and "error" not in result) or (not access_token_from_cache):
return result
except: # The exact HTTP exception is transportation-layer dependent
Expand Down Expand Up @@ -922,7 +926,8 @@ def _get_app_metadata(self, environment):
def _acquire_token_silent_by_finding_specific_refresh_token(
self, authority, scopes, query,
rt_remover=None, break_condition=lambda response: False,
force_refresh=False, correlation_id=None, claims_challenge=None, **kwargs):
refresh_reason=None, correlation_id=None, claims_challenge=None,
**kwargs):
matches = self.token_cache.find(
self.token_cache.CredentialType.REFRESH_TOKEN,
# target=scopes, # AAD RTs are scope-independent
Expand All @@ -931,6 +936,9 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
client = self._build_client(self.client_credential, authority)

response = None # A distinguishable value to mean cache is empty
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_SILENT_ID,
correlation_id=correlation_id, refresh_reason=refresh_reason)
for entry in sorted( # Since unfit RTs would not be aggressively removed,
# we start from newer RTs which are more likely fit.
matches,
Expand All @@ -948,16 +956,13 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
skip_account_creation=True, # To honor a concurrent remove_account()
)),
scope=scopes,
headers={
CLIENT_REQUEST_ID: correlation_id or _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_SILENT_ID, force_refresh=force_refresh),
},
headers=telemetry_context.generate_headers(),
data=dict(
kwargs.pop("data", {}),
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)),
**kwargs)
telemetry_context.update_telemetry(response)
if "error" not in response:
return response
logger.debug("Refresh failed. {error}: {error_description}".format(
Expand Down Expand Up @@ -1006,18 +1011,19 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes, **kwargs):
* A dict contains no "error" key means migration was successful.
"""
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
return _clean_up(self.client.obtain_token_by_refresh_token(
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_BY_REFRESH_TOKEN,
refresh_reason=msal.telemetry.FORCE_REFRESH)
response = _clean_up(self.client.obtain_token_by_refresh_token(
refresh_token,
scope=decorate_scope(scopes, self.client_id),
headers={
CLIENT_REQUEST_ID: _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_REFRESH_TOKEN),
},
headers=telemetry_context.generate_headers(),
rt_getter=lambda rt: rt,
on_updating_rt=False,
on_removing_rt=lambda rt_item: None, # No OP
**kwargs))
telemetry_context.update_telemetry(response)
return response


class PublicClientApplication(ClientApplication): # browser app or mobile app
Expand Down Expand Up @@ -1093,7 +1099,9 @@ def acquire_token_interactive(
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
claims = _merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)
return _clean_up(self.client.obtain_token_by_browser(
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_INTERACTIVE)
response = _clean_up(self.client.obtain_token_by_browser(
scope=decorate_scope(scopes, self.client_id) if scopes else None,
extra_scope_to_consent=extra_scopes_to_consent,
redirect_uri="http://localhost:{port}".format(
Expand All @@ -1107,12 +1115,10 @@ def acquire_token_interactive(
"domain_hint": domain_hint,
},
data=dict(kwargs.pop("data", {}), claims=claims),
headers={
CLIENT_REQUEST_ID: _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_INTERACTIVE),
},
headers=telemetry_context.generate_headers(),
**kwargs))
telemetry_context.update_telemetry(response)
return response

def initiate_device_flow(self, scopes=None, **kwargs):
"""Initiate a Device Flow instance,
Expand All @@ -1125,13 +1131,10 @@ def initiate_device_flow(self, scopes=None, **kwargs):
- A successful response would contain "user_code" key, among others
- an error response would contain some other readable key/value pairs.
"""
correlation_id = _get_new_correlation_id()
correlation_id = msal.telemetry._get_new_correlation_id()
flow = self.client.initiate_device_flow(
scope=decorate_scope(scopes or [], self.client_id),
headers={
CLIENT_REQUEST_ID: correlation_id,
# CLIENT_CURRENT_TELEMETRY is not currently required
},
headers={msal.telemetry.CLIENT_REQUEST_ID: correlation_id},
**kwargs)
flow[self.DEVICE_FLOW_CORRELATION_ID] = correlation_id
return flow
Expand All @@ -1155,7 +1158,10 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs):
- A successful response would contain "access_token" key,
- an error response would contain "error" and usually "error_description".
"""
return _clean_up(self.client.obtain_token_by_device_flow(
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID,
correlation_id=flow.get(self.DEVICE_FLOW_CORRELATION_ID))
response = _clean_up(self.client.obtain_token_by_device_flow(
flow,
data=dict(
kwargs.pop("data", {}),
Expand All @@ -1165,13 +1171,10 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs):
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge),
),
headers={
CLIENT_REQUEST_ID:
flow.get(self.DEVICE_FLOW_CORRELATION_ID) or _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID),
},
headers=telemetry_context.generate_headers(),
**kwargs))
telemetry_context.update_telemetry(response)
return response

def acquire_token_by_username_password(
self, username, password, scopes, claims_challenge=None, **kwargs):
Expand All @@ -1196,28 +1199,30 @@ def acquire_token_by_username_password(
- an error response would contain "error" and usually "error_description".
"""
scopes = decorate_scope(scopes, self.client_id)
headers = {
CLIENT_REQUEST_ID: _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID),
}
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID)
headers = telemetry_context.generate_headers()
data = dict(
kwargs.pop("data", {}),
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge))
if not self.authority.is_adfs:
user_realm_result = self.authority.user_realm_discovery(
username, correlation_id=headers[CLIENT_REQUEST_ID])
username, correlation_id=headers[msal.telemetry.CLIENT_REQUEST_ID])
if user_realm_result.get("account_type") == "Federated":
return _clean_up(self._acquire_token_by_username_password_federated(
response = _clean_up(self._acquire_token_by_username_password_federated(
user_realm_result, username, password, scopes=scopes,
data=data,
headers=headers, **kwargs))
return _clean_up(self.client.obtain_token_by_username_password(
telemetry_context.update_telemetry(response)
return response
response = _clean_up(self.client.obtain_token_by_username_password(
username, password, scope=scopes,
headers=headers,
data=data,
**kwargs))
telemetry_context.update_telemetry(response)
return response

def _acquire_token_by_username_password_federated(
self, user_realm_result, username, password, scopes=None, **kwargs):
Expand Down Expand Up @@ -1277,18 +1282,18 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
"""
# TBD: force_refresh behavior
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
return _clean_up(self.client.obtain_token_for_client(
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_FOR_CLIENT_ID)
response = _clean_up(self.client.obtain_token_for_client(
scope=scopes, # This grant flow requires no scope decoration
headers={
CLIENT_REQUEST_ID: _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_FOR_CLIENT_ID),
},
headers=telemetry_context.generate_headers(),
data=dict(
kwargs.pop("data", {}),
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)),
**kwargs))
telemetry_context.update_telemetry(response)
return response

def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=None, **kwargs):
"""Acquires token using on-behalf-of (OBO) flow.
Expand Down Expand Up @@ -1316,9 +1321,11 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No
- A successful response would contain "access_token" key,
- an error response would contain "error" and usually "error_description".
"""
telemetry_context = self._build_telemetry_context(
self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID)
# The implementation is NOT based on Token Exchange
# https://tools.ietf.org/html/draft-ietf-oauth-token-exchange-16
return _clean_up(self.client.obtain_token_by_assertion( # bases on assertion RFC 7521
response = _clean_up(self.client.obtain_token_by_assertion( # bases on assertion RFC 7521
user_assertion,
self.client.GRANT_TYPE_JWT, # IDTs and AAD ATs are all JWTs
scope=decorate_scope(scopes, self.client_id), # Decoration is used for:
Expand All @@ -1332,9 +1339,8 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No
requested_token_use="on_behalf_of",
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)),
headers={
CLIENT_REQUEST_ID: _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID),
},
headers=telemetry_context.generate_headers(),
**kwargs))
telemetry_context.update_telemetry(response)
return response

Loading

0 comments on commit 31b24af

Please sign in to comment.