Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

App service support for MI #537

Merged
merged 25 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions apps/internal/oauth/ops/accesstokens/accesstokens_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,29 @@ func TestTokenResponseUnmarshal(t *testing.T) {
}`,
want: TokenResponse{
AccessToken: "secret",
ExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)},
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Second * 86399)},
ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)},
GrantedScopes: Scopes{Slice: []string{"openid", "profile"}},
ClientInfo: ClientInfo{
UID: "uid",
UTID: "utid",
},
},
jwtDecoder: jwtDecoderFake,
},
{
desc: "Success",
payload: fmt.Sprintf(`
{
"access_token": "secret",
"expires_on": %d,
"ext_expires_in": 86399,
"client_info": {"uid": "uid","utid": "utid"},
"scope": "openid profile"
}`, time.Now().Add(time.Hour).Unix()),
want: TokenResponse{
AccessToken: "secret",
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)},
GrantedScopes: Scopes{Slice: []string{"openid", "profile"}},
ClientInfo: ClientInfo{
Expand All @@ -795,7 +817,9 @@ func TestTokenResponseUnmarshal(t *testing.T) {
case err != nil:
continue
}

if got.ExpiresOn.T.Unix() != test.want.ExpiresOn.T.Unix() {
t.Errorf("TestCreateTokenResponse: got %v, want %v", got.ExpiresOn.T.Unix(), test.want.ExpiresOn.T.Unix())
}
// Note: IncludeUnexported prevents minor differences in time.Time due to internal fields.
if diff := (&pretty.Config{IncludeUnexported: false}).Compare(test.want, got); diff != "" {
t.Errorf("TestCreateTokenResponse: -want/+got:\n%s", diff)
Expand Down
60 changes: 59 additions & 1 deletion apps/internal/oauth/ops/accesstokens/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ type TokenResponse struct {
FamilyID string `json:"foci"`
IDToken IDToken `json:"id_token"`
ClientInfo ClientInfo `json:"client_info"`
ExpiresOn internalTime.DurationTime `json:"expires_in"`
ExpiresOn internalTime.DurationTime `json:"-"`
4gust marked this conversation as resolved.
Show resolved Hide resolved
ExtExpiresOn internalTime.DurationTime `json:"ext_expires_in"`
GrantedScopes Scopes `json:"scope"`
DeclinedScopes []string // This is derived
Expand All @@ -183,6 +183,64 @@ type TokenResponse struct {
scopesComputed bool
}

func (tr *TokenResponse) UnmarshalJSON(data []byte) error {
type Alias TokenResponse
aux := &struct {
ExpiresIn json.Number `json:"expires_in"`
ExpiresOn json.Number `json:"expires_on"`
4gust marked this conversation as resolved.
Show resolved Hide resolved
*Alias
}{
Alias: (*Alias)(tr),
}

// Unmarshal the JSON data into the aux struct
if err := json.Unmarshal(data, &aux); err != nil {
return err
}

// Helper function to parse JSON number into int64
parseDuration := func(num json.Number) (int64, error) {
4gust marked this conversation as resolved.
Show resolved Hide resolved
if num == "" {
return 0, nil
}
return num.Int64()
}

// Try to parse ExpiresIn first, then fallback to ExpiresOn
4gust marked this conversation as resolved.
Show resolved Hide resolved
if duration, err := parseDuration(aux.ExpiresIn); err != nil {
println("122121@@")
4gust marked this conversation as resolved.
Show resolved Hide resolved
return err
} else if duration > 0 {
println("122121@")

tr.ExpiresOn = internalTime.DurationTime{T: time.Now().Add(time.Duration(duration) * time.Second)}
} else if duration == 0 || aux.ExpiresOn != "" {
println("122121@@@@@")
// If ExpiresIn is zero, check ExpiresOn
if duration, err := parseDuration(aux.ExpiresOn); err != nil {
println("122121@@@@@!")

return err
} else if duration > 0 {
println("122121@@@@@!!")

tr.ExpiresOn = internalTime.DurationTime{T: time.Unix(duration, 0)}
println(tr.ExpiresOn.T.String())

} else {
println("122121@@@@@!!!!!")

return errors.New("expires_in and expires_on are both missing or invalid")
}
} else {
println("122121")

return errors.New("expires_in or expires_on must be present in the response")
}

return nil
}

// ComputeScope computes the final scopes based on what was granted by the server and
// what our AuthParams were from the authority server. Per OAuth spec, if no scopes are returned, the response should be treated as if all scopes were granted
// This behavior can be observed in client assertion flows, but can happen at any time, this check ensures we treat
Expand Down
9 changes: 3 additions & 6 deletions apps/managedidentity/managedidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ const (
wwwAuthenticateHeaderName = "www-authenticate"

// UAMI query parameter name
miQueryParameterClientId = "client_id"
miQueryParameterObjectId = "object_id"
miQueryParameterClientId = "client_id"
miQueryParameterObjectId = "object_id"
miQueryParameterResourceIdIMDS = "msi_res_id"
miQueryParameterResourceId = "mi_res_id"
miQueryParameterResourceId = "mi_res_id"

// IMDS
imdsDefaultEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
Expand Down Expand Up @@ -408,17 +408,14 @@ func (c Client) retry(maxRetries int, req *http.Request) (*http.Response, error)
retrylist = retryCodesForIMDS
}
if err == nil && !contains(retrylist, resp.StatusCode) {
tryCancel()
return resp, nil
}
select {
case <-time.After(time.Second):
case <-req.Context().Done():
err = req.Context().Err()
tryCancel()
return resp, err
}
tryCancel()
}
return resp, err
}
Expand Down
42 changes: 27 additions & 15 deletions apps/managedidentity/managedidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ const (
type SuccessfulResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in"`
ExpiresOn int64 `json:"expires_on"`
Resource string `json:"resource"`
TokenType string `json:"token_type"`
}
Expand All @@ -47,14 +48,26 @@ type ErrorResponse struct {
Desc string `json:"error_description"`
}

func getSuccessfulResponse(resource string) ([]byte, error) {
duration := 10 * time.Minute
expiresIn := duration.Seconds()
response := SuccessfulResponse{
AccessToken: token,
ExpiresIn: int64(expiresIn),
Resource: resource,
TokenType: "Bearer",
func getSuccessfulResponse(resource string, doesHaveExpireIn bool) ([]byte, error) {
var response SuccessfulResponse
if doesHaveExpireIn {
4gust marked this conversation as resolved.
Show resolved Hide resolved
duration := 10 * time.Minute
expiresIn := duration.Seconds()
response = SuccessfulResponse{
AccessToken: token,
ExpiresIn: int64(expiresIn),
Resource: resource,
TokenType: "Bearer",
}
} else {
println(time.Now().Add(time.Hour).Unix())
response = SuccessfulResponse{
AccessToken: token,
ExpiresOn: time.Now().Add(time.Hour).Unix(),
Resource: resource,
TokenType: "Bearer",
}
println(response.ExpiresOn)
}
jsonResponse, err := json.Marshal(response)
return jsonResponse, err
Expand Down Expand Up @@ -278,7 +291,7 @@ func Test_RetryPolicy_For_AcquireToken(t *testing.T) {
}))
}
if !testCase.expectedFail {
successRespBody, err := getSuccessfulResponse(resource)
successRespBody, err := getSuccessfulResponse(resource, true)
if err != nil {
t.Fatalf("error while forming json response : %s", err.Error())
}
Expand Down Expand Up @@ -380,7 +393,7 @@ func TestIMDSAcquireTokenReturnsTokenSuccess(t *testing.T) {

var localUrl *url.URL
mockClient := mock.Client{}
responseBody, err := getSuccessfulResponse(resource)
responseBody, err := getSuccessfulResponse(resource, true)
if err != nil {
t.Fatalf(errorFormingJsonResponse, err.Error())
}
Expand Down Expand Up @@ -467,16 +480,15 @@ func TestAppServiceAcquireTokenReturnsTokenSuccess(t *testing.T) {
{resource: resourceDefaultSuffix, miType: UserAssignedObjectID("objectId")},
}
for _, testCase := range testCases {
t.Run(string(DefaultToIMDS)+"-"+testCase.miType.value(), func(t *testing.T) {
t.Run(string(AppService)+"-"+testCase.miType.value(), func(t *testing.T) {
endpoint := "http://127.0.0.1:41564/msi/token"

var localUrl *url.URL
mockClient := mock.Client{}
responseBody, err := getSuccessfulResponse(resource)
responseBody, err := getSuccessfulResponse(resource, false)
if err != nil {
t.Fatalf(errorFormingJsonResponse, err.Error())
}

mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK),
mock.WithBody(responseBody),
mock.WithCallback(func(r *http.Request) {
Expand Down Expand Up @@ -571,7 +583,7 @@ func TestAzureArc(t *testing.T) {
localUrl = r.URL
}))

responseBody, err := getSuccessfulResponse(resource)
responseBody, err := getSuccessfulResponse(resource, true)
if err != nil {
t.Fatalf(errorFormingJsonResponse, err.Error())
}
Expand Down Expand Up @@ -742,7 +754,7 @@ func TestAzureArcErrors(t *testing.T) {
mock.WithHTTPHeader(headers),
)

responseBody, err := getSuccessfulResponse(resource)
responseBody, err := getSuccessfulResponse(resource, true)
if err != nil {
t.Fatalf(errorFormingJsonResponse, err.Error())
}
Expand Down
Loading