Skip to content

Commit

Permalink
Resource id adjustments
Browse files Browse the repository at this point in the history
Resource ID as mi_res_id in App Service, as msi_res_id in other flavors
  • Loading branch information
rayluo committed Sep 6, 2024
1 parent 0a756e9 commit 28fbf7c
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 4 deletions.
13 changes: 9 additions & 4 deletions msal/managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class ManagedIdentity(UserDict):

_types_mapping = { # Maps type name in configuration to type name on wire
CLIENT_ID: "client_id",
RESOURCE_ID: "mi_res_id",
RESOURCE_ID: "msi_res_id", # VM's IMDS prefers msi_res_id https://github.com/Azure/azure-rest-api-specs/blob/dba6ed1f03bda88ac6884c0a883246446cc72495/specification/imds/data-plane/Microsoft.InstanceMetadataService/stable/2018-10-01/imds.json#L233-L239
OBJECT_ID: "object_id",
}

Expand Down Expand Up @@ -430,9 +430,9 @@ def _obtain_token(http_client, managed_identity, resource):
return _obtain_token_on_azure_vm(http_client, managed_identity, resource)


def _adjust_param(params, managed_identity):
def _adjust_param(params, managed_identity, types_mapping=None):
# Modify the params dict in place
id_name = ManagedIdentity._types_mapping.get(
id_name = (types_mapping or ManagedIdentity._types_mapping).get(
managed_identity.get(ManagedIdentity.ID_TYPE))
if id_name:
params[id_name] = managed_identity[ManagedIdentity.ID]
Expand Down Expand Up @@ -479,7 +479,12 @@ def _obtain_token_on_app_service(
"api-version": "2019-08-01",
"resource": resource,
}
_adjust_param(params, managed_identity)
_adjust_param(params, managed_identity, types_mapping={
ManagedIdentity.CLIENT_ID: "client_id",
ManagedIdentity.RESOURCE_ID: "mi_res_id", # App Service's resource id uses "mi_res_id"
ManagedIdentity.OBJECT_ID: "object_id",
})

resp = http_client.get(
endpoint,
params=params,
Expand Down
32 changes: 32 additions & 0 deletions tests/test_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,22 @@ def test_vm_error_should_be_returned_as_is(self):
json.loads(raw_error), self.app.acquire_token_for_client(resource="R"))
self.assertEqual({}, self.app._token_cache._cache)

def test_vm_resource_id_parameter_should_be_msi_res_id(self):
app = ManagedIdentityClient(
{"ManagedIdentityIdType": "ResourceId", "Id": "1234"},
http_client=requests.Session(),
)
with patch.object(app._http_client, "get", return_value=MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_in": 3600, "resource": "R"}',
)) as mocked_method:
app.acquire_token_for_client(resource="R")
mocked_method.assert_called_with(
'http://169.254.169.254/metadata/identity/oauth2/token',
params={'api-version': '2018-02-01', 'resource': 'R', 'msi_res_id': '1234'},
headers={'Metadata': 'true'},
)


@patch.dict(os.environ, {"IDENTITY_ENDPOINT": "http://localhost", "IDENTITY_HEADER": "foo"})
class AppServiceTestCase(ClientTestCase):
Expand All @@ -164,6 +180,22 @@ def test_app_service_error_should_be_normalized(self):
}, self.app.acquire_token_for_client(resource="R"))
self.assertEqual({}, self.app._token_cache._cache)

def test_app_service_resource_id_parameter_should_be_mi_res_id(self):
app = ManagedIdentityClient(
{"ManagedIdentityIdType": "ResourceId", "Id": "1234"},
http_client=requests.Session(),
)
with patch.object(app._http_client, "get", return_value=MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_on": 12345, "resource": "R"}',
)) as mocked_method:
app.acquire_token_for_client(resource="R")
mocked_method.assert_called_with(
'http://localhost',
params={'api-version': '2019-08-01', 'resource': 'R', 'mi_res_id': '1234'},
headers={'X-IDENTITY-HEADER': 'foo', 'Metadata': 'true'},
)


@patch.dict(os.environ, {"MSI_ENDPOINT": "http://localhost", "MSI_SECRET": "foo"})
class MachineLearningTestCase(ClientTestCase):
Expand Down

0 comments on commit 28fbf7c

Please sign in to comment.