Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for storage managed disk challenge #20418

Merged
merged 10 commits into from
Mar 27, 2023
1 change: 1 addition & 0 deletions sdk/storage/azblob/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Features Added

* Added [Blob Batch API](https://learn.microsoft.com/rest/api/storageservices/blob-batch).
* Added support for bearer challenge for identity based managed disks.

### Breaking Changes

Expand Down
2 changes: 1 addition & 1 deletion sdk/storage/azblob/appendblob/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ type Client base.CompositeClient[generated.BlobClient, generated.AppendBlobClien
// - cred - an Azure AD credential, typically obtained via the azidentity module
// - options - client options; pass nil to accept the default values
func NewClient(blobURL string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) {
authPolicy := runtime.NewBearerTokenPolicy(cred, []string{shared.TokenScope}, nil)
vibhansa-msft marked this conversation as resolved.
Show resolved Hide resolved
authPolicy := shared.NewStorageChallengePolicy(cred)
jhendrixMSFT marked this conversation as resolved.
Show resolved Hide resolved
conOptions := shared.GetClientOptions(options)
conOptions.PerRetryPolicies = append(conOptions.PerRetryPolicies, authPolicy)
pl := runtime.NewPipeline(exported.ModuleName,
Expand Down
2 changes: 1 addition & 1 deletion sdk/storage/azblob/blob/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type Client base.Client[generated.BlobClient]
// - cred - an Azure AD credential, typically obtained via the azidentity module
// - options - client options; pass nil to accept the default values
func NewClient(blobURL string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) {
authPolicy := runtime.NewBearerTokenPolicy(cred, []string{shared.TokenScope}, nil)
authPolicy := shared.NewStorageChallengePolicy(cred)
conOptions := shared.GetClientOptions(options)
conOptions.PerRetryPolicies = append(conOptions.PerRetryPolicies, authPolicy)
pl := runtime.NewPipeline(exported.ModuleName, exported.ModuleVersion, runtime.PipelineOptions{}, &conOptions.ClientOptions)
Expand Down
2 changes: 1 addition & 1 deletion sdk/storage/azblob/blockblob/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type Client base.CompositeClient[generated.BlobClient, generated.BlockBlobClient
// - cred - an Azure AD credential, typically obtained via the azidentity module
// - options - client options; pass nil to accept the default values
func NewClient(blobURL string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) {
authPolicy := runtime.NewBearerTokenPolicy(cred, []string{shared.TokenScope}, nil)
authPolicy := shared.NewStorageChallengePolicy(cred)
conOptions := shared.GetClientOptions(options)
conOptions.PerRetryPolicies = append(conOptions.PerRetryPolicies, authPolicy)
pl := runtime.NewPipeline(exported.ModuleName, exported.ModuleVersion, runtime.PipelineOptions{}, &conOptions.ClientOptions)
Expand Down
4 changes: 2 additions & 2 deletions sdk/storage/azblob/container/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type Client base.Client[generated.ContainerClient]
// - cred - an Azure AD credential, typically obtained via the azidentity module
// - options - client options; pass nil to accept the default values
func NewClient(containerURL string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) {
authPolicy := runtime.NewBearerTokenPolicy(cred, []string{shared.TokenScope}, nil)
authPolicy := shared.NewStorageChallengePolicy(cred)
conOptions := shared.GetClientOptions(options)
conOptions.PerRetryPolicies = append(conOptions.PerRetryPolicies, authPolicy)
pl := runtime.NewPipeline(exported.ModuleName, exported.ModuleVersion, runtime.PipelineOptions{}, &conOptions.ClientOptions)
Expand Down Expand Up @@ -351,7 +351,7 @@ func (c *Client) NewBatchBuilder() (*BatchBuilder, error) {

switch cred := c.credential().(type) {
case *azcore.TokenCredential:
authPolicy = runtime.NewBearerTokenPolicy(*cred, []string{shared.TokenScope}, nil)
authPolicy = shared.NewStorageChallengePolicy(*cred)
case *SharedKeyCredential:
authPolicy = exported.NewSharedKeyCredPolicy(cred)
case nil:
Expand Down
114 changes: 114 additions & 0 deletions sdk/storage/azblob/internal/shared/challenge_policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
//go:build go1.18
// +build go1.18

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.

package shared

import (
"errors"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"net/http"
"strings"
)

type storageAuthorizer struct {
scopes []string
tenantID string
}

func NewStorageChallengePolicy(cred azcore.TokenCredential) policy.Policy {
s := storageAuthorizer{scopes: []string{TokenScope}}
return runtime.NewBearerTokenPolicy(cred, []string{TokenScope}, &policy.BearerTokenOptions{
AuthorizationHandler: policy.AuthorizationHandler{
OnRequest: s.onRequest,
OnChallenge: s.onChallenge,
},
})
}

func (s *storageAuthorizer) onRequest(req *policy.Request, authNZ func(policy.TokenRequestOptions) error) error {
if len(s.scopes) == 0 || s.tenantID == "" {
// returning nil indicates the bearer token policy should send the request
return nil
}
return authNZ(policy.TokenRequestOptions{Scopes: s.scopes})
}

func (s *storageAuthorizer) onChallenge(req *policy.Request, resp *http.Response, authNZ func(policy.TokenRequestOptions) error) error {
// parse the challenge
err := s.parseChallenge(resp)
if err != nil {
return err
}
// TODO: Set tenantID when policy.TokenRequestOptions supports it. https://github.com/Azure/azure-sdk-for-go/issues/19841
return authNZ(policy.TokenRequestOptions{Scopes: s.scopes})
}

type challengePolicyError struct {
err error
}

func (c *challengePolicyError) Error() string {
return c.err.Error()
}

func (*challengePolicyError) NonRetriable() {
// marker method
}

func (c *challengePolicyError) Unwrap() error {
return c.err
}

// parses Tenant ID from auth challenge
// https://login.microsoftonline.com/00000000-0000-0000-0000-000000000000/oauth2/authorize
func parseTenant(url string) string {
if url == "" {
return ""
}
parts := strings.Split(url, "/")
tenant := parts[3]
souravgupta-msft marked this conversation as resolved.
Show resolved Hide resolved
tenant = strings.ReplaceAll(tenant, ",", "")
return tenant
}

func (s *storageAuthorizer) parseChallenge(resp *http.Response) error {
authHeader := resp.Header.Get("WWW-Authenticate")
if authHeader == "" {
return &challengePolicyError{err: errors.New("response has no WWW-Authenticate header for challenge authentication")}
}

// Strip down to auth and resource
// Format is "Bearer authorization_uri=\"<site>\" resource_id=\"<site>\""
authHeader = strings.ReplaceAll(authHeader, "Bearer ", "")

parts := strings.Split(authHeader, " ")

vals := map[string]string{}
for _, part := range parts {
subParts := strings.Split(part, "=")
if len(subParts) == 2 {
stripped := strings.ReplaceAll(subParts[1], "\"", "")
stripped = strings.TrimSuffix(stripped, ",")
vals[subParts[0]] = stripped
}
}

s.tenantID = parseTenant(vals["authorization_uri"])

scope := vals["resource_id"]
if scope == "" {
return &challengePolicyError{err: errors.New("could not find a valid resource in the WWW-Authenticate header")}
}

if !strings.HasSuffix(scope, "/.default") {
scope = strings.TrimSuffix(scope, "/") // Resource might come back with /
scope += "/.default"
}
s.scopes = []string{scope}
return nil
}
84 changes: 84 additions & 0 deletions sdk/storage/azblob/internal/shared/challenge_policy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
//go:build go1.18
// +build go1.18

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.

package shared

import (
"context"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
"github.com/stretchr/testify/require"
"net/http"
"strings"
"testing"
"time"
)

type credentialFunc func(context.Context, policy.TokenRequestOptions) (azcore.AccessToken, error)

func (cf credentialFunc) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) {
return cf(ctx, options)
}

func TestChallengePolicy(t *testing.T) {
accessToken := "***"
storageResource := "https://storage.azure.com"
storageScope := "https://storage.azure.com/.default"
challenge := `Bearer authorization_uri="https://login.microsoftonline.com/{tenant}", resource_id="{storageResource}"`
diskResource := "https://disk.azure.com/"
diskScope := "https://disk.azure.com/.default"

for _, test := range []struct {
expectedScope, format, resource string
}{
{format: challenge, resource: storageResource, expectedScope: storageScope},
{format: challenge, resource: diskResource, expectedScope: diskScope},
} {
t.Run("", func(t *testing.T) {
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.AppendResponse(
mock.WithHeader("WWW-Authenticate", strings.ReplaceAll(test.format, "{storageResource}", test.resource)),
mock.WithStatusCode(401),
)
srv.AppendResponse(mock.WithPredicate(func(r *http.Request) bool {
if authz := r.Header.Values("Authorization"); len(authz) != 1 || authz[0] != "Bearer "+accessToken {
t.Errorf(`unexpected Authorization "%s"`, authz)
}
return true
}))
srv.AppendResponse()
authenticated := false
cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
authenticated = true
require.Equal(t, []string{test.expectedScope}, tro.Scopes)
return azcore.AccessToken{Token: accessToken, ExpiresOn: time.Now().Add(time.Hour)}, nil
})
p := NewStorageChallengePolicy(cred)
pl := runtime.NewPipeline("", "",
runtime.PipelineOptions{PerRetry: []policy.Policy{p}},
&policy.ClientOptions{Transport: srv},
)
req, err := runtime.NewRequest(context.Background(), "GET", "https://localhost")
require.NoError(t, err)
_, err = pl.Do(req)
require.NoError(t, err)
require.True(t, authenticated, "policy should have authenticated")
})
}
}

func TestParseTenant(t *testing.T) {
actual := parseTenant("")
require.Empty(t, actual)

expected := "00000000-0000-0000-0000-000000000000"
sampleURL := "https://login.microsoftonline.com/" + expected
actual = parseTenant(sampleURL)
require.Equal(t, expected, actual, "tenant was not properly parsed")
}
2 changes: 1 addition & 1 deletion sdk/storage/azblob/pageblob/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type Client base.CompositeClient[generated.BlobClient, generated.PageBlobClient]
// - cred - an Azure AD credential, typically obtained via the azidentity module
// - options - client options; pass nil to accept the default values
func NewClient(blobURL string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) {
authPolicy := runtime.NewBearerTokenPolicy(cred, []string{shared.TokenScope}, nil)
authPolicy := shared.NewStorageChallengePolicy(cred)
conOptions := shared.GetClientOptions(options)
conOptions.PerRetryPolicies = append(conOptions.PerRetryPolicies, authPolicy)
pl := runtime.NewPipeline(exported.ModuleName, exported.ModuleVersion, runtime.PipelineOptions{}, &conOptions.ClientOptions)
Expand Down
4 changes: 2 additions & 2 deletions sdk/storage/azblob/service/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type Client base.Client[generated.ServiceClient]
// - cred - an Azure AD credential, typically obtained via the azidentity module
// - options - client options; pass nil to accept the default values
func NewClient(serviceURL string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) {
authPolicy := runtime.NewBearerTokenPolicy(cred, []string{shared.TokenScope}, nil)
authPolicy := shared.NewStorageChallengePolicy(cred)
conOptions := shared.GetClientOptions(options)
conOptions.PerRetryPolicies = append(conOptions.PerRetryPolicies, authPolicy)
pl := runtime.NewPipeline(exported.ModuleName, exported.ModuleVersion, runtime.PipelineOptions{}, &conOptions.ClientOptions)
Expand Down Expand Up @@ -303,7 +303,7 @@ func (s *Client) NewBatchBuilder() (*BatchBuilder, error) {

switch cred := s.credential().(type) {
case *azcore.TokenCredential:
authPolicy = runtime.NewBearerTokenPolicy(*cred, []string{shared.TokenScope}, nil)
authPolicy = shared.NewStorageChallengePolicy(*cred)
case *SharedKeyCredential:
authPolicy = exported.NewSharedKeyCredPolicy(cred)
case nil:
Expand Down