Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

App service support for MI #537

Open
wants to merge 17 commits into
base: andyohart/managed-identity
Choose a base branch
from
28 changes: 26 additions & 2 deletions apps/internal/oauth/ops/accesstokens/accesstokens_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,29 @@ func TestTokenResponseUnmarshal(t *testing.T) {
}`,
want: TokenResponse{
AccessToken: "secret",
ExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)},
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Second * 86399)},
ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)},
GrantedScopes: Scopes{Slice: []string{"openid", "profile"}},
ClientInfo: ClientInfo{
UID: "uid",
UTID: "utid",
},
},
jwtDecoder: jwtDecoderFake,
},
{
desc: "Success",
payload: fmt.Sprintf(`
{
"access_token": "secret",
"expires_on": %d,
"ext_expires_in": 86399,
"client_info": {"uid": "uid","utid": "utid"},
"scope": "openid profile"
}`, time.Now().Add(time.Hour).Unix()),
want: TokenResponse{
AccessToken: "secret",
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)},
GrantedScopes: Scopes{Slice: []string{"openid", "profile"}},
ClientInfo: ClientInfo{
Expand All @@ -795,7 +817,9 @@ func TestTokenResponseUnmarshal(t *testing.T) {
case err != nil:
continue
}

if got.ExpiresOn.T.Unix() != test.want.ExpiresOn.T.Unix() {
t.Errorf("TestCreateTokenResponse: got %v, want %v", got.ExpiresOn.T.Unix(), test.want.ExpiresOn.T.Unix())
}
// Note: IncludeUnexported prevents minor differences in time.Time due to internal fields.
if diff := (&pretty.Config{IncludeUnexported: false}).Compare(test.want, got); diff != "" {
t.Errorf("TestCreateTokenResponse: -want/+got:\n%s", diff)
Expand Down
60 changes: 59 additions & 1 deletion apps/internal/oauth/ops/accesstokens/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ type TokenResponse struct {
FamilyID string `json:"foci"`
IDToken IDToken `json:"id_token"`
ClientInfo ClientInfo `json:"client_info"`
ExpiresOn internalTime.DurationTime `json:"expires_in"`
ExpiresOn internalTime.DurationTime `json:"-"`
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DurationTime exists to calculate ExpiresOn from expires_in. That doesn't help now that you want to consider a possible value of expires_on as well, and you're disabling unmarshaling for this field anyway. So, why not make the type here simply time.Time?

ExtExpiresOn internalTime.DurationTime `json:"ext_expires_in"`
GrantedScopes Scopes `json:"scope"`
DeclinedScopes []string // This is derived
Expand All @@ -183,6 +183,64 @@ type TokenResponse struct {
scopesComputed bool
}

func (tr *TokenResponse) UnmarshalJSON(data []byte) error {
type Alias TokenResponse
aux := &struct {
ExpiresIn json.Number `json:"expires_in"`
ExpiresOn json.Number `json:"expires_on"`
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could simplify this method by continuing to use internalTime.DurationTime to calculate ExpiresOn from expires_in:

Suggested change
ExpiresIn json.Number `json:"expires_in"`
ExpiresOn json.Number `json:"expires_on"`
ExpiresOn json.Number `json:"expires_on,omitempty"`
ExpiresOnCalculated internalTime.DurationTime `json:"expires_in,omitempty"`

then after unmarshaling, if aux.ExpiresOn is nonzero, you can convert it to a time.Time for the TokenResponse. If that field is zero, ExpiresOnCalculated.T is the time.Time you want

*Alias
}{
Alias: (*Alias)(tr),
}

// Unmarshal the JSON data into the aux struct
if err := json.Unmarshal(data, &aux); err != nil {
return err
}

// Helper function to parse JSON number into int64
parseDuration := func(num json.Number) (int64, error) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is this better than DurationTime.UnmarshalJSON i.e., what do you gain by unmarshaling expires_in to json.Number instead of DurationTime?

if num == "" {
return 0, nil
}
return num.Int64()
}

// Try to parse ExpiresIn first, then fallback to ExpiresOn
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be the other way around--prefer expires_on if the response includes it, otherwise calculate from expires_in?

if duration, err := parseDuration(aux.ExpiresIn); err != nil {
println("122121@@")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👀

return err
} else if duration > 0 {
println("122121@")

tr.ExpiresOn = internalTime.DurationTime{T: time.Now().Add(time.Duration(duration) * time.Second)}
} else if duration == 0 || aux.ExpiresOn != "" {
println("122121@@@@@")
// If ExpiresIn is zero, check ExpiresOn
if duration, err := parseDuration(aux.ExpiresOn); err != nil {
println("122121@@@@@!")

return err
} else if duration > 0 {
println("122121@@@@@!!")

tr.ExpiresOn = internalTime.DurationTime{T: time.Unix(duration, 0)}
println(tr.ExpiresOn.T.String())

} else {
println("122121@@@@@!!!!!")

return errors.New("expires_in and expires_on are both missing or invalid")
}
} else {
println("122121")

return errors.New("expires_in or expires_on must be present in the response")
}

return nil
}

// ComputeScope computes the final scopes based on what was granted by the server and
// what our AuthParams were from the authority server. Per OAuth spec, if no scopes are returned, the response should be treated as if all scopes were granted
// This behavior can be observed in client assertion flows, but can happen at any time, this check ensures we treat
Expand Down
77 changes: 60 additions & 17 deletions apps/managedidentity/managedidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ const (
wwwAuthenticateHeaderName = "www-authenticate"

// UAMI query parameter name
miQueryParameterClientId = "client_id"
miQueryParameterObjectId = "object_id"
miQueryParameterResourceId = "msi_res_id"
miQueryParameterClientId = "client_id"
miQueryParameterObjectId = "object_id"
miQueryParameterResourceIdIMDS = "msi_res_id"
miQueryParameterResourceId = "mi_res_id"

// IMDS
imdsDefaultEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
Expand All @@ -66,6 +67,9 @@ const (
himdsExecutableName = "himds.exe"
tokenName = "Tokens"

// App Service
appServiceAPIVersion = "2019-08-01"

// Environment Variables
identityEndpointEnvVar = "IDENTITY_ENDPOINT"
identityHeaderEnvVar = "IDENTITY_HEADER"
Expand Down Expand Up @@ -213,6 +217,7 @@ func New(id ID, options ...ClientOption) (Client, error) {
for _, option := range options {
option(&opts)
}

switch t := id.(type) {
case UserAssignedClientID:
if len(string(t)) == 0 {
Expand Down Expand Up @@ -290,36 +295,49 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac
return ar, err
}
}

switch c.source {
case AzureArc:
return acquireTokenForAzureArc(ctx, c, resource)
return c.acquireTokenForAzureArc(ctx, resource)
case DefaultToIMDS:
return acquireTokenForIMDS(ctx, c, resource)
return c.acquireTokenForIMDS(ctx, resource)
case AppService:
return c.acquireTokenForAppService(ctx, resource)
default:
return base.AuthResult{}, fmt.Errorf("unsupported source %q", c.source)
}
}

func acquireTokenForIMDS(ctx context.Context, client Client, resource string) (base.AuthResult, error) {
req, err := createIMDSAuthRequest(ctx, client.miType, resource)
func (c Client) acquireTokenForAppService(ctx context.Context, resource string) (base.AuthResult, error) {
req, err := createAppServiceAuthRequest(ctx, c.miType, resource)
if err != nil {
return base.AuthResult{}, err
}
tokenResponse, err := client.getTokenForRequest(req)
tokenResponse, err := c.getTokenForRequest(req)
if err != nil {
return base.AuthResult{}, err
}
return authResultFromToken(client.authParams, tokenResponse)
return authResultFromToken(c.authParams, tokenResponse)
}

func acquireTokenForAzureArc(ctx context.Context, client Client, resource string) (base.AuthResult, error) {
func (c Client) acquireTokenForIMDS(ctx context.Context, resource string) (base.AuthResult, error) {
req, err := createIMDSAuthRequest(ctx, c.miType, resource)
if err != nil {
return base.AuthResult{}, err
}
tokenResponse, err := c.getTokenForRequest(req)
if err != nil {
return base.AuthResult{}, err
}
return authResultFromToken(c.authParams, tokenResponse)
}

func (c Client) acquireTokenForAzureArc(ctx context.Context, resource string) (base.AuthResult, error) {
req, err := createAzureArcAuthRequest(ctx, resource, "")
if err != nil {
return base.AuthResult{}, err
}

response, err := client.httpClient.Do(req)
response, err := c.httpClient.Do(req)
if err != nil {
return base.AuthResult{}, err
}
Expand All @@ -329,7 +347,7 @@ func acquireTokenForAzureArc(ctx context.Context, client Client, resource string
return base.AuthResult{}, fmt.Errorf("expected a 401 response, received %d", response.StatusCode)
}

secret, err := client.getAzureArcSecretKey(response, runtime.GOOS)
secret, err := c.getAzureArcSecretKey(response, runtime.GOOS)
if err != nil {
return base.AuthResult{}, err
}
Expand All @@ -339,11 +357,11 @@ func acquireTokenForAzureArc(ctx context.Context, client Client, resource string
return base.AuthResult{}, err
}

tokenResponse, err := client.getTokenForRequest(secondRequest)
tokenResponse, err := c.getTokenForRequest(secondRequest)
if err != nil {
return base.AuthResult{}, err
}
return authResultFromToken(client.authParams, tokenResponse)
return authResultFromToken(c.authParams, tokenResponse)
}

func authResultFromToken(authParams authority.AuthParams, token accesstokens.TokenResponse) (base.AuthResult, error) {
Expand Down Expand Up @@ -377,7 +395,7 @@ func (c Client) retry(maxRetries int, req *http.Request) (*http.Response, error)
var resp *http.Response
var err error
for attempt := 0; attempt < maxRetries; attempt++ {
tryCtx, tryCancel := context.WithTimeout(req.Context(), time.Second*15)
tryCtx, tryCancel := context.WithTimeout(req.Context(), time.Minute)
defer tryCancel()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the problem with deferring this?

if resp != nil && resp.Body != nil {
_, _ = io.Copy(io.Discard, resp.Body)
Expand Down Expand Up @@ -446,6 +464,31 @@ func (c Client) getTokenForRequest(req *http.Request) (accesstokens.TokenRespons
return r, err
}

func createAppServiceAuthRequest(ctx context.Context, id ID, resource string) (*http.Request, error) {
identityEndpoint := os.Getenv(identityEndpointEnvVar)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, identityEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("X-IDENTITY-HEADER", os.Getenv(identityHeaderEnvVar))
q := req.URL.Query()
q.Set("api-version", appServiceAPIVersion)
q.Set("resource", resource)
switch t := id.(type) {
case UserAssignedClientID:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are general purpose params, not related to App Service. So instead of logging them here, can we instead have logging statements at the start of the MSI "AcquireToken" request, which display all config values?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And similarly, after the HTTP request is done ... log failures and in case of success log things like "got an access token", expires in etc. Don't log the actual access token.

q.Set(miQueryParameterClientId, string(t))
case UserAssignedResourceID:
q.Set(miQueryParameterResourceId, string(t))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
q.Set(miQueryParameterResourceId, string(t))
q.Set("mi_res_id", string(t))

Copy link
Collaborator Author

@4gust 4gust Dec 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change only for App service ?
AzureAD/microsoft-authentication-library-for-dotnet#4911
This bug was created to change this to "msi_res_id"

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. That issue is about supporting ACI, which imitates IMDS and requires msi_res_id, in accordance with IMDS docs. App Service requires mi_res_id. Every platform has its own managed identity implementation, so subtle differences like this are normal

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We were hoping to use msi_res_id everywhere. Seems like AppService supports it. Thoughts?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If they support it, they should document it; we shouldn't depend on undocumented behavior.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will test this for on Azure Function with all the different user assigned ways and update this comment on the findings.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did some testing,
I Created a App service function as assigned a UserAssigned managed Identity and a SystemAssigned managed identity.
Then I tried to access a KeyVault using the token we acquired.
Created 4 different client one for each type

SystemAssgiened() -- No Querry parameter value in the request.
UserAssignedClientID("") -- client_id client id was assigned to this parameter
UserAssignedObjectID("") -- object_id Object id was assigned to this parameter
UserAssignedResourceID("") -- msi_res_id Resource id was assigned to this parameter

Outcome -
I was able to get AccessToken (verified the token on the jwt.ms) from all the types of the managed identities, using that token I was able to get the secret that was stored in the KeyVault

Copy link
Member

@bgavrilMS bgavrilMS Dec 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tend to agree with @chlowell on this one though. The documented parameter for App Service (and for all other sources except IMDS I believe) is mi_res_id - see https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#rest-endpoint-reference

It'd be safer to use that and use msi_res_id only for IMDS.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will update with this in mind

case UserAssignedObjectID:
q.Set(miQueryParameterObjectId, string(t))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
q.Set(miQueryParameterObjectId, string(t))
q.Set("principal_id", string(t))

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see this in .net or other msal's
Is this changed for App service only?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but it appears I'm out of date on this point, as the docs now state the API will accept object_id as an alias for principal_id. 🤷 I'd still make this change because Azure SDK uses, and tests, principal_id, and these docs have been incorrect before (i.e. if you want to keep using object_id, make sure you test it)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested with object_id and it worked, would you recommend that I also test with principal_id?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should run live tests before merging in any case

case systemAssignedValue:
default:
return nil, fmt.Errorf("unsupported type %T", id)
}
req.URL.RawQuery = q.Encode()
return req, nil
}

func createIMDSAuthRequest(ctx context.Context, id ID, resource string) (*http.Request, error) {
msiEndpoint, err := url.Parse(imdsDefaultEndpoint)
if err != nil {
Expand All @@ -459,7 +502,7 @@ func createIMDSAuthRequest(ctx context.Context, id ID, resource string) (*http.R
case UserAssignedClientID:
msiParameters.Set(miQueryParameterClientId, string(t))
case UserAssignedResourceID:
msiParameters.Set(miQueryParameterResourceId, string(t))
msiParameters.Set(miQueryParameterResourceIdIMDS, string(t))
case UserAssignedObjectID:
msiParameters.Set(miQueryParameterObjectId, string(t))
case systemAssignedValue: // not adding anything
Expand Down
Loading
Loading