Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NewManagedIdentityCredential returns an error for unsupported ID options #23267

Merged
merged 3 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@
### Features Added

### Breaking Changes
* `NewManagedIdentityCredential` now returns an error when a user-assigned identity
is specified on a platform whose managed identity API doesn't support that.
`ManagedIdentityCredential.GetToken()` formerly logged a warning in these cases.
Returning an error instead prevents the credential authenticating an unexpected
identity, causing a client to act with unexpected privileges. The affected
platforms are:
* Azure Arc
* Azure ML (when a resource ID is specified; client IDs are supported)
* Cloud Shell
* Service Fabric

### Bugs Fixed

Expand Down
25 changes: 25 additions & 0 deletions sdk/azidentity/default_azure_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package azidentity

import (
"context"
"errors"
"fmt"
"net/http"
"os"
Expand Down Expand Up @@ -371,3 +372,27 @@ func TestDefaultAzureCredential_IMDS(t *testing.T) {
require.Equal(t, tokenValue, tk.Token)
})
}

func TestDefaultAzureCredential_UnsupportedMIClientID(t *testing.T) {
fail := true
before := defaultAzTokenProvider
defer func() { defaultAzTokenProvider = before }()
defaultAzTokenProvider = func(ctx context.Context, scopes []string, tenant, subscription string) ([]byte, error) {
if fail {
return nil, errors.New("fail")
}
return mockAzTokenProviderSuccess(ctx, scopes, tenant, subscription)
}
t.Setenv(azureClientID, fakeClientID)
t.Setenv(msiEndpoint, fakeMIEndpoint)

cred, err := NewDefaultAzureCredential(nil)
require.NoError(t, err, "an unsupported client ID isn't a constructor error")

_, err = cred.GetToken(ctx, testTRO)
require.ErrorContains(t, err, "Cloud Shell", "error should mention the unsupported ID")

fail = false
_, err = cred.GetToken(ctx, testTRO)
require.NoError(t, err, "expected a token from AzureCLICredential")
}
53 changes: 19 additions & 34 deletions sdk/azidentity/managed_identity_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) (*manag
if endpoint, ok := os.LookupEnv(identityEndpoint); ok {
if _, ok := os.LookupEnv(identityHeader); ok {
if _, ok := os.LookupEnv(identityServerThumbprint); ok {
if options.ID != nil {
return nil, errors.New("the Service Fabric API doesn't support specifying a user-assigned managed identity at runtime")
}
env = "Service Fabric"
c.endpoint = endpoint
c.msiType = msiTypeServiceFabric
Expand All @@ -152,16 +155,25 @@ func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) (*manag
c.msiType = msiTypeAppService
}
} else if _, ok := os.LookupEnv(arcIMDSEndpoint); ok {
if options.ID != nil {
return nil, errors.New("the Azure Arc API doesn't support specifying a user-assigned managed identity at runtime")
}
env = "Azure Arc"
c.endpoint = endpoint
c.msiType = msiTypeAzureArc
}
} else if endpoint, ok := os.LookupEnv(msiEndpoint); ok {
c.endpoint = endpoint
if _, ok := os.LookupEnv(msiSecret); ok {
if options.ID != nil && options.ID.idKind() == miResourceID {
return nil, errors.New("the Azure ML API doesn't support specifying a managed identity by resource ID")
}
env = "Azure ML"
c.msiType = msiTypeAzureML
} else {
if options.ID != nil {
return nil, errors.New("the Cloud Shell API doesn't support user-assigned managed identities")
}
env = "Cloud Shell"
c.msiType = msiTypeCloudShell
}
Expand Down Expand Up @@ -304,13 +316,13 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, id Manage
msg := fmt.Sprintf("failed to retreive secret key from the identity endpoint: %v", err)
return nil, newAuthenticationFailedError(credNameManagedIdentity, msg, nil, err)
}
return c.createAzureArcAuthRequest(ctx, id, scopes, key)
return c.createAzureArcAuthRequest(ctx, scopes, key)
case msiTypeAzureML:
return c.createAzureMLAuthRequest(ctx, id, scopes)
case msiTypeServiceFabric:
return c.createServiceFabricAuthRequest(ctx, id, scopes)
return c.createServiceFabricAuthRequest(ctx, scopes)
case msiTypeCloudShell:
return c.createCloudShellAuthRequest(ctx, id, scopes)
return c.createCloudShellAuthRequest(ctx, scopes)
default:
return nil, newCredentialUnavailableError(credNameManagedIdentity, "managed identity isn't supported in this environment")
}
Expand Down Expand Up @@ -368,9 +380,7 @@ func (c *managedIdentityClient) createAzureMLAuthRequest(ctx context.Context, id
q.Add("clientid", os.Getenv(defaultIdentityClientID))
if id != nil {
if id.idKind() == miResourceID {
log.Write(EventAuthentication, "WARNING: Azure ML doesn't support specifying a managed identity by resource ID")
q.Set("clientid", "")
q.Set(miResID, id.String())
return nil, newAuthenticationFailedError(credNameManagedIdentity, "Azure ML doesn't support specifying a managed identity by resource ID", nil, nil)
} else {
q.Set("clientid", id.String())
}
Expand All @@ -379,7 +389,7 @@ func (c *managedIdentityClient) createAzureMLAuthRequest(ctx context.Context, id
return request, nil
}

func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Context, scopes []string) (*policy.Request, error) {
request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
if err != nil {
return nil, err
Expand All @@ -389,14 +399,6 @@ func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Conte
request.Raw().Header.Set("Secret", os.Getenv(identityHeader))
q.Add("api-version", serviceFabricAPIVersion)
q.Add("resource", strings.Join(scopes, " "))
if id != nil {
log.Write(EventAuthentication, "WARNING: Service Fabric doesn't support selecting a user-assigned identity at runtime")
if id.idKind() == miResourceID {
q.Add(miResID, id.String())
} else {
q.Add(qpClientID, id.String())
}
}
request.Raw().URL.RawQuery = q.Encode()
return request, nil
}
Expand Down Expand Up @@ -453,7 +455,7 @@ func (c *managedIdentityClient) getAzureArcSecretKey(ctx context.Context, resour
return string(key), nil
}

func (c *managedIdentityClient) createAzureArcAuthRequest(ctx context.Context, id ManagedIDKind, resources []string, key string) (*policy.Request, error) {
func (c *managedIdentityClient) createAzureArcAuthRequest(ctx context.Context, resources []string, key string) (*policy.Request, error) {
request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
if err != nil {
return nil, err
Expand All @@ -463,19 +465,11 @@ func (c *managedIdentityClient) createAzureArcAuthRequest(ctx context.Context, i
q := request.Raw().URL.Query()
q.Add("api-version", azureArcAPIVersion)
q.Add("resource", strings.Join(resources, " "))
if id != nil {
log.Write(EventAuthentication, "WARNING: Azure Arc doesn't support user-assigned managed identities")
if id.idKind() == miResourceID {
q.Add(miResID, id.String())
} else {
q.Add(qpClientID, id.String())
}
}
request.Raw().URL.RawQuery = q.Encode()
return request, nil
}

func (c *managedIdentityClient) createCloudShellAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
func (c *managedIdentityClient) createCloudShellAuthRequest(ctx context.Context, scopes []string) (*policy.Request, error) {
request, err := azruntime.NewRequest(ctx, http.MethodPost, c.endpoint)
if err != nil {
return nil, err
Expand All @@ -488,14 +482,5 @@ func (c *managedIdentityClient) createCloudShellAuthRequest(ctx context.Context,
if err := request.SetBody(body, "application/x-www-form-urlencoded"); err != nil {
return nil, err
}
if id != nil {
log.Write(EventAuthentication, "WARNING: Cloud Shell doesn't support user-assigned managed identities")
q := request.Raw().URL.Query()
if id.idKind() == miResourceID {
q.Add(miResID, id.String())
} else {
q.Add(qpClientID, id.String())
}
}
return request, nil
}
61 changes: 0 additions & 61 deletions sdk/azidentity/managed_identity_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
)

Expand Down Expand Up @@ -123,63 +122,3 @@ func TestManagedIdentityClient_IMDSErrors(t *testing.T) {
})
}
}

func TestManagedIdentityClient_UserAssignedIDWarning(t *testing.T) {
for _, test := range []struct {
name string
createRequest func(*managedIdentityClient) error
}{
{
name: "Azure Arc",
createRequest: func(client *managedIdentityClient) error {
_, err := client.createAzureArcAuthRequest(context.Background(), client.id, []string{liveTestScope}, "key")
return err
},
},
{
name: "Cloud Shell",
createRequest: func(client *managedIdentityClient) error {
_, err := client.createCloudShellAuthRequest(context.Background(), client.id, []string{liveTestScope})
return err
},
},
{
name: "Service Fabric",
createRequest: func(client *managedIdentityClient) error {
_, err := client.createServiceFabricAuthRequest(context.Background(), client.id, []string{liveTestScope})
return err
},
},
} {
for _, id := range []ManagedIDKind{ClientID(fakeClientID), ResourceID(fakeResourceID)} {
s := "-ClientID"
if id.String() == fakeResourceID {
s = "-ResourceID"
}
t.Run(test.name+s, func(t *testing.T) {
msgs := []string{}
log.SetListener(func(event log.Event, msg string) {
if event == EventAuthentication {
msgs = append(msgs, msg)
}
})
client, err := newManagedIdentityClient(&ManagedIdentityCredentialOptions{
ID: id,
})
if err != nil {
t.Fatal(err)
}
err = test.createRequest(client)
if err != nil {
t.Fatal(err)
}
for _, msg := range msgs {
if strings.Contains(msg, test.name) && strings.Contains(msg, "user-assigned") {
return
}
}
t.Fatalf("expected warning about user-assigned ID, got:\n%s", strings.Join(msgs, "\n"))
})
}
}
}
22 changes: 17 additions & 5 deletions sdk/azidentity/managed_identity_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ type ManagedIDKind interface {
idKind() managedIdentityIDKind
}

// ClientID is the client ID of a user-assigned managed identity.
// ClientID is the client ID of a user-assigned managed identity. NewManagedIdentityCredential
// returns an error when a ClientID is specified on the following platforms:
//
// - Azure Arc
// - Cloud Shell
// - Service Fabric
type ClientID string

func (ClientID) idKind() managedIdentityIDKind {
Expand All @@ -44,7 +49,13 @@ func (c ClientID) String() string {
return string(c)
}

// ResourceID is the resource ID of a user-assigned managed identity.
// ResourceID is the resource ID of a user-assigned managed identity. NewManagedIdentityCredential
// returns an error when a ResourceID is specified on the following platforms:
//
// - Azure Arc
// - Azure ML
// - Cloud Shell
// - Service Fabric
type ResourceID string

func (ResourceID) idKind() managedIdentityIDKind {
Expand All @@ -60,9 +71,10 @@ func (r ResourceID) String() string {
type ManagedIdentityCredentialOptions struct {
azcore.ClientOptions

// ID is the ID of a managed identity the credential should authenticate. Set this field to use a specific identity
// instead of the hosting environment's default. The value may be the identity's client ID or resource ID, but note that
// some platforms don't accept resource IDs.
// ID of a managed identity the credential should authenticate. Set this field to use a specific identity instead of
// the hosting environment's default. The value may be the identity's client ID or resource ID.
// NewManagedIdentityCredential returns an error when the hosting environment doesn't support user-assigned managed
// identities, or the specified kind of ID.
chlowell marked this conversation as resolved.
Show resolved Hide resolved
ID ManagedIDKind

// dac indicates whether the credential is part of DefaultAzureCredential. When true, and the environment doesn't have
Expand Down
35 changes: 35 additions & 0 deletions sdk/azidentity/managed_identity_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -673,3 +673,38 @@ func TestManagedIdentityCredential_ServiceFabric(t *testing.T) {
}
testGetTokenSuccess(t, cred)
}

func TestManagedIdentityCredential_UnsupportedID(t *testing.T) {
chlowell marked this conversation as resolved.
Show resolved Hide resolved
t.Run("Azure Arc", func(t *testing.T) {
t.Setenv(identityEndpoint, fakeMIEndpoint)
t.Setenv(arcIMDSEndpoint, fakeMIEndpoint)
_, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ResourceID(fakeResourceID)})
require.Error(t, err)
_, err = NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ClientID(fakeClientID)})
require.Error(t, err)
})
t.Run("Azure ML", func(t *testing.T) {
t.Setenv(msiEndpoint, fakeMIEndpoint)
t.Setenv(msiSecret, "...")
_, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ResourceID(fakeResourceID)})
require.Error(t, err)
_, err = NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ClientID(fakeClientID)})
require.NoError(t, err)
})
t.Run("Cloud Shell", func(t *testing.T) {
t.Setenv(msiEndpoint, fakeMIEndpoint)
_, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ResourceID(fakeResourceID)})
require.Error(t, err)
_, err = NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ClientID(fakeClientID)})
require.Error(t, err)
})
t.Run("Service Fabric", func(t *testing.T) {
t.Setenv(identityEndpoint, fakeMIEndpoint)
t.Setenv(identityHeader, "...")
t.Setenv(identityServerThumbprint, "...")
_, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ResourceID(fakeResourceID)})
require.Error(t, err)
_, err = NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ClientID(fakeClientID)})
require.Error(t, err)
})
}