Skip to content

Commit

Permalink
Revise AuthenticationRecord to align with other SDKs (#17689)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Apr 5, 2021
1 parent 2264517 commit 472b1d4
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 19 deletions.
30 changes: 23 additions & 7 deletions sdk/identity/azure-identity/azure/identity/_auth_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,17 @@
import json


SUPPORTED_VERSIONS = {"1.0"}


class AuthenticationRecord(object):
"""A record which can initialize :class:`DeviceCodeCredential` or :class:`InteractiveBrowserCredential`"""
"""Non-secret account information for an authenticated user
This class enables :class:`DeviceCodeCredential` and :class:`InteractiveBrowserCredential` to access
previously cached authentication data. Applications shouldn't construct instances of this class. They should
instead acquire one from a credential's **authenticate** method, such as
:func:`InteractiveBrowserCredential.authenticate`. See the user_authentication sample for more details.
"""

def __init__(self, tenant_id, client_id, authority, home_account_id, username):
# type: (str, str, str, str, str) -> None
Expand Down Expand Up @@ -52,11 +61,17 @@ def deserialize(cls, data):

deserialized = json.loads(data)

version = deserialized.get("version")
if version not in SUPPORTED_VERSIONS:
raise ValueError(
'Unexpected version "{}". This package supports these versions: {}'.format(version, SUPPORTED_VERSIONS)
)

return cls(
authority=deserialized["authority"],
client_id=deserialized["client_id"],
home_account_id=deserialized["home_account_id"],
tenant_id=deserialized["tenant_id"],
client_id=deserialized["clientId"],
home_account_id=deserialized["homeAccountId"],
tenant_id=deserialized["tenantId"],
username=deserialized["username"],
)

Expand All @@ -69,10 +84,11 @@ def serialize(self):

record = {
"authority": self._authority,
"client_id": self._client_id,
"home_account_id": self._home_account_id,
"tenant_id": self._tenant_id,
"clientId": self._client_id,
"homeAccountId": self._home_account_id,
"tenantId": self._tenant_id,
"username": self._username,
"version": "1.0",
}

return json.dumps(record)
56 changes: 44 additions & 12 deletions sdk/identity/azure-identity/tests/test_auth_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,57 @@
import json

from azure.identity import AuthenticationRecord
from azure.identity._auth_record import SUPPORTED_VERSIONS
import pytest


def test_serialization():
"""serialize should accept arbitrary additional key/value pairs, which deserialize should ignore"""

attrs = ("authority", "client_id","home_account_id", "tenant_id", "username")
nums = (n for n in range(len(attrs)))
record_values = {attr: next(nums) for attr in attrs}

record = AuthenticationRecord(**record_values)
expected = {
"authority": "http://localhost",
"clientId": "client-id",
"homeAccountId": "object-id.tenant-id",
"tenantId": "tenant-id",
"username": "user",
"version": "1.0",
}

record = AuthenticationRecord(
expected["tenantId"],
expected["clientId"],
expected["authority"],
expected["homeAccountId"],
expected["username"],
)
serialized = record.serialize()

# AuthenticationRecord's fields should have been serialized
assert json.loads(serialized) == record_values
assert json.loads(serialized) == expected

deserialized = AuthenticationRecord.deserialize(serialized)

# the deserialized record and the constructed record should have the same fields
assert sorted(vars(deserialized)) == sorted(vars(record))

# the constructed and deserialized records should have the same values
assert all(getattr(deserialized, attr) == record_values[attr] for attr in attrs)
assert record.authority == deserialized.authority == expected["authority"]
assert record.client_id == deserialized.client_id == expected["clientId"]
assert record.home_account_id == deserialized.home_account_id == expected["homeAccountId"]
assert record.tenant_id == deserialized.tenant_id == expected["tenantId"]
assert record.username == deserialized.username == expected["username"]


@pytest.mark.parametrize("version", ("42", None))
def test_unknown_version(version):
"""deserialize should raise ValueError when the data doesn't contain a known version"""

data = {
"authority": "http://localhost",
"clientId": "client-id",
"homeAccountId": "object-id.tenant-id",
"tenantId": "tenant-id",
"username": "user",
}

if version:
data["version"] = version

with pytest.raises(ValueError, match=".*{}.*".format(version)) as ex:
AuthenticationRecord.deserialize(json.dumps(data))
assert str(SUPPORTED_VERSIONS) in str(ex.value)

0 comments on commit 472b1d4

Please sign in to comment.