Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

App service support for MI #537

Open
wants to merge 17 commits into
base: andyohart/managed-identity
Choose a base branch
from
60 changes: 52 additions & 8 deletions apps/managedidentity/managedidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ const (
himdsExecutableName = "himds.exe"
tokenName = "Tokens"

// App Service
appServiceAPIVersion = "2019-08-01"

// Environment Variables
identityEndpointEnvVar = "IDENTITY_ENDPOINT"
identityHeaderEnvVar = "IDENTITY_HEADER"
Expand Down Expand Up @@ -213,6 +216,7 @@ func New(id ID, options ...ClientOption) (Client, error) {
for _, option := range options {
option(&opts)
}

switch t := id.(type) {
case UserAssignedClientID:
if len(string(t)) == 0 {
Expand Down Expand Up @@ -290,27 +294,40 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac
return ar, err
}
}

switch c.source {
case AzureArc:
return acquireTokenForAzureArc(ctx, c, resource)
case DefaultToIMDS:
return acquireTokenForIMDS(ctx, c, resource)
case AppService:
return acquireTokenForAppService(ctx, c, resource)
default:
return base.AuthResult{}, fmt.Errorf("unsupported source %q", c.source)
}
}

func acquireTokenForIMDS(ctx context.Context, client Client, resource string) (base.AuthResult, error) {
req, err := createIMDSAuthRequest(ctx, client.miType, resource)
func acquireTokenForAppService(ctx context.Context, c Client, resource string) (base.AuthResult, error) {
AndyOHart marked this conversation as resolved.
Show resolved Hide resolved
req, err := createAppServiceAuthRequest(ctx, c.miType, resource, c)
if err != nil {
return base.AuthResult{}, err
}
tokenResponse, err := client.getTokenForRequest(req)
tokenResponse, err := c.getTokenForRequest(req)
if err != nil {
return base.AuthResult{}, err
}
return authResultFromToken(client.authParams, tokenResponse)
return authResultFromToken(c.authParams, tokenResponse)
}

func acquireTokenForIMDS(ctx context.Context, c Client, resource string) (base.AuthResult, error) {
req, err := createIMDSAuthRequest(ctx, c.miType, resource, c)
if err != nil {
return base.AuthResult{}, err
}
tokenResponse, err := c.getTokenForRequest(req)
if err != nil {
return base.AuthResult{}, err
}
return authResultFromToken(c.authParams, tokenResponse)
}

func acquireTokenForAzureArc(ctx context.Context, client Client, resource string) (base.AuthResult, error) {
Expand Down Expand Up @@ -377,8 +394,7 @@ func (c Client) retry(maxRetries int, req *http.Request) (*http.Response, error)
var resp *http.Response
var err error
for attempt := 0; attempt < maxRetries; attempt++ {
tryCtx, tryCancel := context.WithTimeout(req.Context(), time.Second*15)
defer tryCancel()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the problem with deferring this?

tryCtx, tryCancel := context.WithTimeout(req.Context(), time.Second*60)
AndyOHart marked this conversation as resolved.
Show resolved Hide resolved
if resp != nil && resp.Body != nil {
_, _ = io.Copy(io.Discard, resp.Body)
resp.Body.Close()
Expand All @@ -390,14 +406,17 @@ func (c Client) retry(maxRetries int, req *http.Request) (*http.Response, error)
retrylist = retryCodesForIMDS
}
if err == nil && !contains(retrylist, resp.StatusCode) {
tryCancel()
4gust marked this conversation as resolved.
Show resolved Hide resolved
return resp, nil
}
select {
case <-time.After(time.Second):
case <-req.Context().Done():
err = req.Context().Err()
tryCancel()
return resp, err
}
tryCancel()
}
return resp, err
}
Expand Down Expand Up @@ -446,7 +465,32 @@ func (c Client) getTokenForRequest(req *http.Request) (accesstokens.TokenRespons
return r, err
}

func createIMDSAuthRequest(ctx context.Context, id ID, resource string) (*http.Request, error) {
func createAppServiceAuthRequest(ctx context.Context, id ID, resource string, c Client) (*http.Request, error) {
identityEndpoint := os.Getenv(identityEndpointEnvVar)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, identityEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("X-IDENTITY-HEADER", os.Getenv(identityHeaderEnvVar))
q := req.URL.Query()
q.Set("api-version", appServiceAPIVersion)
q.Set("resource", resource)
switch t := id.(type) {
case UserAssignedClientID:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are general purpose params, not related to App Service. So instead of logging them here, can we instead have logging statements at the start of the MSI "AcquireToken" request, which display all config values?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And similarly, after the HTTP request is done ... log failures and in case of success log things like "got an access token", expires in etc. Don't log the actual access token.

q.Set(miQueryParameterClientId, string(t))
case UserAssignedResourceID:
q.Set(miQueryParameterResourceId, string(t))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
q.Set(miQueryParameterResourceId, string(t))
q.Set("mi_res_id", string(t))

Copy link
Collaborator Author

@4gust 4gust Dec 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change only for App service ?
AzureAD/microsoft-authentication-library-for-dotnet#4911
This bug was created to change this to "msi_res_id"

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. That issue is about supporting ACI, which imitates IMDS and requires msi_res_id, in accordance with IMDS docs. App Service requires mi_res_id. Every platform has its own managed identity implementation, so subtle differences like this are normal

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We were hoping to use msi_res_id everywhere. Seems like AppService supports it. Thoughts?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If they support it, they should document it; we shouldn't depend on undocumented behavior.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will test this for on Azure Function with all the different user assigned ways and update this comment on the findings.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did some testing,
I Created a App service function as assigned a UserAssigned managed Identity and a SystemAssigned managed identity.
Then I tried to access a KeyVault using the token we acquired.
Created 4 different client one for each type

SystemAssgiened() -- No Querry parameter value in the request.
UserAssignedClientID("") -- client_id client id was assigned to this parameter
UserAssignedObjectID("") -- object_id Object id was assigned to this parameter
UserAssignedResourceID("") -- msi_res_id Resource id was assigned to this parameter

Outcome -
I was able to get AccessToken (verified the token on the jwt.ms) from all the types of the managed identities, using that token I was able to get the secret that was stored in the KeyVault

Copy link
Member

@bgavrilMS bgavrilMS Dec 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tend to agree with @chlowell on this one though. The documented parameter for App Service (and for all other sources except IMDS I believe) is mi_res_id - see https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#rest-endpoint-reference

It'd be safer to use that and use msi_res_id only for IMDS.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will update with this in mind

case UserAssignedObjectID:
q.Set(miQueryParameterObjectId, string(t))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
q.Set(miQueryParameterObjectId, string(t))
q.Set("principal_id", string(t))

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see this in .net or other msal's
Is this changed for App service only?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but it appears I'm out of date on this point, as the docs now state the API will accept object_id as an alias for principal_id. 🤷 I'd still make this change because Azure SDK uses, and tests, principal_id, and these docs have been incorrect before (i.e. if you want to keep using object_id, make sure you test it)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested with object_id and it worked, would you recommend that I also test with principal_id?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should run live tests before merging in any case

case systemAssignedValue:
default:
return nil, fmt.Errorf("unsupported type %T", id)
}
req.URL.RawQuery = q.Encode()
return req, nil
}

func createIMDSAuthRequest(ctx context.Context, id ID, resource string, c Client) (*http.Request, error) {
msiEndpoint, err := url.Parse(imdsDefaultEndpoint)
if err != nil {
return nil, fmt.Errorf("couldn't parse %q: %s", imdsDefaultEndpoint, err)
Expand Down
93 changes: 93 additions & 0 deletions apps/managedidentity/managedidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,99 @@ func TestIMDSAcquireTokenReturnsTokenSuccess(t *testing.T) {
}
}

func TestAppServiceAcquireTokenReturnsTokenSuccess(t *testing.T) {
setEnvVars(t, AppService)
testCases := []struct {
resource string
miType ID
}{
{resource: resource, miType: SystemAssigned()},
{resource: resourceDefaultSuffix, miType: SystemAssigned()},
{resource: resource, miType: UserAssignedClientID("clientId")},
{resource: resourceDefaultSuffix, miType: UserAssignedResourceID("resourceId")},
{resource: resourceDefaultSuffix, miType: UserAssignedObjectID("objectId")},
}
for _, testCase := range testCases {
t.Run(string(DefaultToIMDS)+"-"+testCase.miType.value(), func(t *testing.T) {
endpoint := "http://127.0.0.1:41564/msi/token"

var localUrl *url.URL
mockClient := mock.Client{}
responseBody, err := getSuccessfulResponse(resource)
if err != nil {
t.Fatalf(errorFormingJsonResponse, err.Error())
}

mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK),
mock.WithBody(responseBody),
mock.WithCallback(func(r *http.Request) {
localUrl = r.URL
}))
// resetting cache
before := cacheManager
defer func() { cacheManager = before }()
cacheManager = storage.New(nil)

client, err := New(testCase.miType, WithHTTPClient(&mockClient))
if err != nil {
t.Fatal(err)
}
result, err := client.AcquireToken(context.Background(), testCase.resource)
if err != nil {
t.Fatal(err)
}
if localUrl == nil || !strings.HasPrefix(localUrl.String(), endpoint) {
t.Fatalf("url request is not on %s got %s", endpoint, localUrl)
}
query := localUrl.Query()

if query.Get(apiVersionQueryParameterName) != appServiceAPIVersion {
t.Fatalf("api-version not on %s got %s", appServiceAPIVersion, query.Get(apiVersionQueryParameterName))
}
if query.Get(resourceQueryParameterName) != strings.TrimSuffix(testCase.resource, "/.default") {
4gust marked this conversation as resolved.
Show resolved Hide resolved
t.Fatal("suffix /.default was not removed.")
}
switch i := testCase.miType.(type) {
case UserAssignedClientID:
if query.Get(miQueryParameterClientId) != i.value() {
t.Fatalf("resource client-id is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterClientId))
}
case UserAssignedResourceID:
if query.Get(miQueryParameterResourceId) != i.value() {
t.Fatalf("resource resource-id is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterResourceId))
}
case UserAssignedObjectID:
if query.Get(miQueryParameterObjectId) != i.value() {
t.Fatalf("resource objectid is incorrect, wanted %s got %s", i.value(), query.Get(miQueryParameterObjectId))
}
}
if result.Metadata.TokenSource != base.IdentityProvider {
t.Fatalf("expected IndenityProvider tokensource, got %d", result.Metadata.TokenSource)
}
if result.AccessToken != token {
t.Fatalf("wanted %q, got %q", token, result.AccessToken)
}
result, err = client.AcquireToken(context.Background(), testCase.resource)
if err != nil {
t.Fatal(err)
}
if result.Metadata.TokenSource != base.Cache {
t.Fatalf("wanted cache token source, got %d", result.Metadata.TokenSource)
}
secondFakeClient, err := New(testCase.miType, WithHTTPClient(&mockClient))
if err != nil {
t.Fatal(err)
}
result, err = secondFakeClient.AcquireToken(context.Background(), testCase.resource)
if err != nil {
t.Fatal(err)
}
if result.Metadata.TokenSource != base.Cache {
t.Fatalf("cache result wanted cache token source, got %d", result.Metadata.TokenSource)
}
})
}
}
func TestAzureArc(t *testing.T) {
testCaseFilePath := filepath.Join(t.TempDir(), azureConnectedMachine)

Expand Down