From 32f5e82d395b9bdcb1dd7bf1728551c06bd9c335 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Thu, 17 Oct 2024 09:00:31 -0700 Subject: [PATCH] BearerTokenPolicy rewinds bodies before retrying (#23597) --- sdk/azcore/CHANGELOG.md | 6 +- sdk/azcore/internal/shared/constants.go | 2 +- sdk/azcore/runtime/policy_bearer_token.go | 13 +++-- .../runtime/policy_bearer_token_test.go | 57 +++++++++++++++++++ 4 files changed, 68 insertions(+), 10 deletions(-) diff --git a/sdk/azcore/CHANGELOG.md b/sdk/azcore/CHANGELOG.md index 607d992fd8de..f88b277ab632 100644 --- a/sdk/azcore/CHANGELOG.md +++ b/sdk/azcore/CHANGELOG.md @@ -1,16 +1,14 @@ # Release History -## 1.15.1 (Unreleased) +## 1.16.0 (2024-10-17) ### Features Added * Added field `Kind` to `runtime.StartSpanOptions` to allow a kind to be set when starting a span. -### Breaking Changes - ### Bugs Fixed -### Other Changes +* `BearerTokenPolicy` now rewinds request bodies before retrying ## 1.15.0 (2024-10-14) diff --git a/sdk/azcore/internal/shared/constants.go b/sdk/azcore/internal/shared/constants.go index a1512f54aa29..9f53770e5b69 100644 --- a/sdk/azcore/internal/shared/constants.go +++ b/sdk/azcore/internal/shared/constants.go @@ -40,5 +40,5 @@ const ( Module = "azcore" // Version is the semantic version (see http://semver.org) of this module. - Version = "v1.15.1" + Version = "v1.16.0" ) diff --git a/sdk/azcore/runtime/policy_bearer_token.go b/sdk/azcore/runtime/policy_bearer_token.go index 7ed66be3807f..b26db920b092 100644 --- a/sdk/azcore/runtime/policy_bearer_token.go +++ b/sdk/azcore/runtime/policy_bearer_token.go @@ -142,14 +142,17 @@ func (b *BearerTokenPolicy) handleChallenge(req *policy.Request, res *http.Respo tro.Claims = caeChallenge.params["claims"] return b.authenticateAndAuthorize(req)(tro) } - err = b.authzHandler.OnRequest(req, authNZ) - if err == nil { - res, err = req.Next() + if err = b.authzHandler.OnRequest(req, authNZ); err == nil { + if err = req.RewindBody(); err == nil { + res, err = req.Next() + } } case b.authzHandler.OnChallenge != nil && !recursed: if err = b.authzHandler.OnChallenge(req, res, b.authenticateAndAuthorize(req)); err == nil { - if res, err = req.Next(); err == nil { - res, err = b.handleChallenge(req, res, true) + if err = req.RewindBody(); err == nil { + if res, err = req.Next(); err == nil { + res, err = b.handleChallenge(req, res, true) + } } } else { // don't retry challenge handling errors diff --git a/sdk/azcore/runtime/policy_bearer_token_test.go b/sdk/azcore/runtime/policy_bearer_token_test.go index 917dec3400d6..9bda9fa5594d 100644 --- a/sdk/azcore/runtime/policy_bearer_token_test.go +++ b/sdk/azcore/runtime/policy_bearer_token_test.go @@ -7,6 +7,7 @@ import ( "context" "encoding/base64" "fmt" + "io" "strings" "errors" @@ -17,6 +18,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" "github.com/stretchr/testify/require" @@ -531,6 +533,61 @@ func TestBearerTokenPolicy_RequiresHTTPS(t *testing.T) { require.ErrorAs(t, err, &nre) } +func TestBearerTokenPolicy_RewindsBeforeRetry(t *testing.T) { + const expected = "expected" + for _, test := range []struct { + challenge, desc string + onChallenge bool + }{ + { + desc: "CAE challenge", + challenge: `Bearer error="insufficient_claims", claims="ey=="`, + }, + { + desc: "non-CAE challenge", + challenge: `Bearer authorization_uri="https://login.windows.net/", error="invalid_token"`, + onChallenge: true, + }, + } { + t.Run(test.desc, func(t *testing.T) { + read := func(r *http.Request) bool { + actual, err := io.ReadAll(r.Body) + require.NoError(t, err, "request should have body content") + require.EqualValues(t, expected, actual) + return true + } + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse( + mock.WithHeader(shared.HeaderWWWAuthenticate, test.challenge), + mock.WithPredicate(read), + mock.WithStatusCode(http.StatusUnauthorized), + ) + srv.AppendResponse() + srv.AppendResponse(mock.WithPredicate(read)) + srv.AppendResponse() + + called := false + o := &policy.BearerTokenOptions{} + if test.onChallenge { + o.AuthorizationHandler.OnChallenge = func(*policy.Request, *http.Response, func(policy.TokenRequestOptions) error) error { + called = true + return nil + } + } + b := NewBearerTokenPolicy(mockCredential{}, []string{scope}, o) + pl := newTestPipeline(&policy.ClientOptions{PerRetryPolicies: []policy.Policy{b}, Transport: srv}) + req, err := NewRequest(context.Background(), http.MethodPost, srv.URL()) + require.NoError(t, err) + require.NoError(t, req.SetBody(streaming.NopCloser(strings.NewReader(expected)), "text/plain")) + + _, err = pl.Do(req) + require.NoError(t, err) + require.Equal(t, test.onChallenge, called, "policy should call OnChallenge when set") + }) + } +} + func TestCheckHTTPSForAuth(t *testing.T) { req, err := NewRequest(context.Background(), http.MethodGet, "http://contoso.com") require.NoError(t, err)