From 84d213cff168fcf348252490ee9f950ecfef394f Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Tue, 6 Aug 2024 10:51:11 -0700 Subject: [PATCH] NewManagedIdentityCredential returns an error for unsupported ID options (#23267) --- sdk/azidentity/CHANGELOG.md | 10 +++ .../default_azure_credential_test.go | 25 ++++++++ sdk/azidentity/managed_identity_client.go | 53 ++++++---------- .../managed_identity_client_test.go | 61 ------------------- sdk/azidentity/managed_identity_credential.go | 22 +++++-- .../managed_identity_credential_test.go | 35 +++++++++++ 6 files changed, 106 insertions(+), 100 deletions(-) diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index 7f3ae0e25944..6d65da7c790d 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -5,6 +5,16 @@ ### Features Added ### Breaking Changes +* `NewManagedIdentityCredential` now returns an error when a user-assigned identity + is specified on a platform whose managed identity API doesn't support that. + `ManagedIdentityCredential.GetToken()` formerly logged a warning in these cases. + Returning an error instead prevents the credential authenticating an unexpected + identity, causing a client to act with unexpected privileges. The affected + platforms are: + * Azure Arc + * Azure ML (when a resource ID is specified; client IDs are supported) + * Cloud Shell + * Service Fabric ### Bugs Fixed diff --git a/sdk/azidentity/default_azure_credential_test.go b/sdk/azidentity/default_azure_credential_test.go index 0581832cf5a4..f8dd3b6d806c 100644 --- a/sdk/azidentity/default_azure_credential_test.go +++ b/sdk/azidentity/default_azure_credential_test.go @@ -8,6 +8,7 @@ package azidentity import ( "context" + "errors" "fmt" "io" "net/http" @@ -395,3 +396,27 @@ func TestDefaultAzureCredential_IMDS(t *testing.T) { require.Equal(t, tokenValue, tk.Token) }) } + +func TestDefaultAzureCredential_UnsupportedMIClientID(t *testing.T) { + fail := true + before := defaultAzTokenProvider + defer func() { defaultAzTokenProvider = before }() + defaultAzTokenProvider = func(ctx context.Context, scopes []string, tenant, subscription string) ([]byte, error) { + if fail { + return nil, errors.New("fail") + } + return mockAzTokenProviderSuccess(ctx, scopes, tenant, subscription) + } + t.Setenv(azureClientID, fakeClientID) + t.Setenv(msiEndpoint, fakeMIEndpoint) + + cred, err := NewDefaultAzureCredential(nil) + require.NoError(t, err, "an unsupported client ID isn't a constructor error") + + _, err = cred.GetToken(ctx, testTRO) + require.ErrorContains(t, err, "Cloud Shell", "error should mention the unsupported ID") + + fail = false + _, err = cred.GetToken(ctx, testTRO) + require.NoError(t, err, "expected a token from AzureCLICredential") +} diff --git a/sdk/azidentity/managed_identity_client.go b/sdk/azidentity/managed_identity_client.go index b7504c34a2f0..e0dad3db20ea 100644 --- a/sdk/azidentity/managed_identity_client.go +++ b/sdk/azidentity/managed_identity_client.go @@ -143,6 +143,9 @@ func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) (*manag if endpoint, ok := os.LookupEnv(identityEndpoint); ok { if _, ok := os.LookupEnv(identityHeader); ok { if _, ok := os.LookupEnv(identityServerThumbprint); ok { + if options.ID != nil { + return nil, errors.New("the Service Fabric API doesn't support specifying a user-assigned managed identity at runtime") + } env = "Service Fabric" c.endpoint = endpoint c.msiType = msiTypeServiceFabric @@ -152,6 +155,9 @@ func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) (*manag c.msiType = msiTypeAppService } } else if _, ok := os.LookupEnv(arcIMDSEndpoint); ok { + if options.ID != nil { + return nil, errors.New("the Azure Arc API doesn't support specifying a user-assigned managed identity at runtime") + } env = "Azure Arc" c.endpoint = endpoint c.msiType = msiTypeAzureArc @@ -159,9 +165,15 @@ func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) (*manag } else if endpoint, ok := os.LookupEnv(msiEndpoint); ok { c.endpoint = endpoint if _, ok := os.LookupEnv(msiSecret); ok { + if options.ID != nil && options.ID.idKind() == miResourceID { + return nil, errors.New("the Azure ML API doesn't support specifying a managed identity by resource ID") + } env = "Azure ML" c.msiType = msiTypeAzureML } else { + if options.ID != nil { + return nil, errors.New("the Cloud Shell API doesn't support user-assigned managed identities") + } env = "Cloud Shell" c.msiType = msiTypeCloudShell } @@ -314,13 +326,13 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, id Manage msg := fmt.Sprintf("failed to retreive secret key from the identity endpoint: %v", err) return nil, newAuthenticationFailedError(credNameManagedIdentity, msg, nil, err) } - return c.createAzureArcAuthRequest(ctx, id, scopes, key) + return c.createAzureArcAuthRequest(ctx, scopes, key) case msiTypeAzureML: return c.createAzureMLAuthRequest(ctx, id, scopes) case msiTypeServiceFabric: - return c.createServiceFabricAuthRequest(ctx, id, scopes) + return c.createServiceFabricAuthRequest(ctx, scopes) case msiTypeCloudShell: - return c.createCloudShellAuthRequest(ctx, id, scopes) + return c.createCloudShellAuthRequest(ctx, scopes) default: return nil, newCredentialUnavailableError(credNameManagedIdentity, "managed identity isn't supported in this environment") } @@ -378,9 +390,7 @@ func (c *managedIdentityClient) createAzureMLAuthRequest(ctx context.Context, id q.Add("clientid", os.Getenv(defaultIdentityClientID)) if id != nil { if id.idKind() == miResourceID { - log.Write(EventAuthentication, "WARNING: Azure ML doesn't support specifying a managed identity by resource ID") - q.Set("clientid", "") - q.Set(miResID, id.String()) + return nil, newAuthenticationFailedError(credNameManagedIdentity, "Azure ML doesn't support specifying a managed identity by resource ID", nil, nil) } else { q.Set("clientid", id.String()) } @@ -389,7 +399,7 @@ func (c *managedIdentityClient) createAzureMLAuthRequest(ctx context.Context, id return request, nil } -func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) { +func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Context, scopes []string) (*policy.Request, error) { request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint) if err != nil { return nil, err @@ -399,14 +409,6 @@ func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Conte request.Raw().Header.Set("Secret", os.Getenv(identityHeader)) q.Add("api-version", serviceFabricAPIVersion) q.Add("resource", strings.Join(scopes, " ")) - if id != nil { - log.Write(EventAuthentication, "WARNING: Service Fabric doesn't support selecting a user-assigned identity at runtime") - if id.idKind() == miResourceID { - q.Add(miResID, id.String()) - } else { - q.Add(qpClientID, id.String()) - } - } request.Raw().URL.RawQuery = q.Encode() return request, nil } @@ -463,7 +465,7 @@ func (c *managedIdentityClient) getAzureArcSecretKey(ctx context.Context, resour return string(key), nil } -func (c *managedIdentityClient) createAzureArcAuthRequest(ctx context.Context, id ManagedIDKind, resources []string, key string) (*policy.Request, error) { +func (c *managedIdentityClient) createAzureArcAuthRequest(ctx context.Context, resources []string, key string) (*policy.Request, error) { request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint) if err != nil { return nil, err @@ -473,19 +475,11 @@ func (c *managedIdentityClient) createAzureArcAuthRequest(ctx context.Context, i q := request.Raw().URL.Query() q.Add("api-version", azureArcAPIVersion) q.Add("resource", strings.Join(resources, " ")) - if id != nil { - log.Write(EventAuthentication, "WARNING: Azure Arc doesn't support user-assigned managed identities") - if id.idKind() == miResourceID { - q.Add(miResID, id.String()) - } else { - q.Add(qpClientID, id.String()) - } - } request.Raw().URL.RawQuery = q.Encode() return request, nil } -func (c *managedIdentityClient) createCloudShellAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) { +func (c *managedIdentityClient) createCloudShellAuthRequest(ctx context.Context, scopes []string) (*policy.Request, error) { request, err := azruntime.NewRequest(ctx, http.MethodPost, c.endpoint) if err != nil { return nil, err @@ -498,14 +492,5 @@ func (c *managedIdentityClient) createCloudShellAuthRequest(ctx context.Context, if err := request.SetBody(body, "application/x-www-form-urlencoded"); err != nil { return nil, err } - if id != nil { - log.Write(EventAuthentication, "WARNING: Cloud Shell doesn't support user-assigned managed identities") - q := request.Raw().URL.Query() - if id.idKind() == miResourceID { - q.Add(miResID, id.String()) - } else { - q.Add(qpClientID, id.String()) - } - } return request, nil } diff --git a/sdk/azidentity/managed_identity_client_test.go b/sdk/azidentity/managed_identity_client_test.go index 78564a261cfe..78d54c4dcb42 100644 --- a/sdk/azidentity/managed_identity_client_test.go +++ b/sdk/azidentity/managed_identity_client_test.go @@ -17,7 +17,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/Azure/azure-sdk-for-go/sdk/internal/log" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -123,63 +122,3 @@ func TestManagedIdentityClient_IMDSErrors(t *testing.T) { }) } } - -func TestManagedIdentityClient_UserAssignedIDWarning(t *testing.T) { - for _, test := range []struct { - name string - createRequest func(*managedIdentityClient) error - }{ - { - name: "Azure Arc", - createRequest: func(client *managedIdentityClient) error { - _, err := client.createAzureArcAuthRequest(context.Background(), client.id, []string{liveTestScope}, "key") - return err - }, - }, - { - name: "Cloud Shell", - createRequest: func(client *managedIdentityClient) error { - _, err := client.createCloudShellAuthRequest(context.Background(), client.id, []string{liveTestScope}) - return err - }, - }, - { - name: "Service Fabric", - createRequest: func(client *managedIdentityClient) error { - _, err := client.createServiceFabricAuthRequest(context.Background(), client.id, []string{liveTestScope}) - return err - }, - }, - } { - for _, id := range []ManagedIDKind{ClientID(fakeClientID), ResourceID(fakeResourceID)} { - s := "-ClientID" - if id.String() == fakeResourceID { - s = "-ResourceID" - } - t.Run(test.name+s, func(t *testing.T) { - msgs := []string{} - log.SetListener(func(event log.Event, msg string) { - if event == EventAuthentication { - msgs = append(msgs, msg) - } - }) - client, err := newManagedIdentityClient(&ManagedIdentityCredentialOptions{ - ID: id, - }) - if err != nil { - t.Fatal(err) - } - err = test.createRequest(client) - if err != nil { - t.Fatal(err) - } - for _, msg := range msgs { - if strings.Contains(msg, test.name) && strings.Contains(msg, "user-assigned") { - return - } - } - t.Fatalf("expected warning about user-assigned ID, got:\n%s", strings.Join(msgs, "\n")) - }) - } - } -} diff --git a/sdk/azidentity/managed_identity_credential.go b/sdk/azidentity/managed_identity_credential.go index 13c043d8e0ce..83ccf8c3b571 100644 --- a/sdk/azidentity/managed_identity_credential.go +++ b/sdk/azidentity/managed_identity_credential.go @@ -32,7 +32,12 @@ type ManagedIDKind interface { idKind() managedIdentityIDKind } -// ClientID is the client ID of a user-assigned managed identity. +// ClientID is the client ID of a user-assigned managed identity. NewManagedIdentityCredential +// returns an error when a ClientID is specified on the following platforms: +// +// - Azure Arc +// - Cloud Shell +// - Service Fabric type ClientID string func (ClientID) idKind() managedIdentityIDKind { @@ -44,7 +49,13 @@ func (c ClientID) String() string { return string(c) } -// ResourceID is the resource ID of a user-assigned managed identity. +// ResourceID is the resource ID of a user-assigned managed identity. NewManagedIdentityCredential +// returns an error when a ResourceID is specified on the following platforms: +// +// - Azure Arc +// - Azure ML +// - Cloud Shell +// - Service Fabric type ResourceID string func (ResourceID) idKind() managedIdentityIDKind { @@ -60,9 +71,10 @@ func (r ResourceID) String() string { type ManagedIdentityCredentialOptions struct { azcore.ClientOptions - // ID is the ID of a managed identity the credential should authenticate. Set this field to use a specific identity - // instead of the hosting environment's default. The value may be the identity's client ID or resource ID, but note that - // some platforms don't accept resource IDs. + // ID of a managed identity the credential should authenticate. Set this field to use a specific identity instead of + // the hosting environment's default. The value may be the identity's client ID or resource ID. + // NewManagedIdentityCredential returns an error when the hosting environment doesn't support user-assigned managed + // identities, or the specified kind of ID. ID ManagedIDKind // dac indicates whether the credential is part of DefaultAzureCredential. When true, and the environment doesn't have diff --git a/sdk/azidentity/managed_identity_credential_test.go b/sdk/azidentity/managed_identity_credential_test.go index c93d7cc9d7cd..5d0c9f84073e 100644 --- a/sdk/azidentity/managed_identity_credential_test.go +++ b/sdk/azidentity/managed_identity_credential_test.go @@ -673,3 +673,38 @@ func TestManagedIdentityCredential_ServiceFabric(t *testing.T) { } testGetTokenSuccess(t, cred) } + +func TestManagedIdentityCredential_UnsupportedID(t *testing.T) { + t.Run("Azure Arc", func(t *testing.T) { + t.Setenv(identityEndpoint, fakeMIEndpoint) + t.Setenv(arcIMDSEndpoint, fakeMIEndpoint) + _, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ResourceID(fakeResourceID)}) + require.Error(t, err) + _, err = NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ClientID(fakeClientID)}) + require.Error(t, err) + }) + t.Run("Azure ML", func(t *testing.T) { + t.Setenv(msiEndpoint, fakeMIEndpoint) + t.Setenv(msiSecret, "...") + _, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ResourceID(fakeResourceID)}) + require.Error(t, err) + _, err = NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ClientID(fakeClientID)}) + require.NoError(t, err) + }) + t.Run("Cloud Shell", func(t *testing.T) { + t.Setenv(msiEndpoint, fakeMIEndpoint) + _, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ResourceID(fakeResourceID)}) + require.Error(t, err) + _, err = NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ClientID(fakeClientID)}) + require.Error(t, err) + }) + t.Run("Service Fabric", func(t *testing.T) { + t.Setenv(identityEndpoint, fakeMIEndpoint) + t.Setenv(identityHeader, "...") + t.Setenv(identityServerThumbprint, "...") + _, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ResourceID(fakeResourceID)}) + require.Error(t, err) + _, err = NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ClientID(fakeClientID)}) + require.Error(t, err) + }) +}