diff --git a/api/applications.go b/api/applications.go index fd575fab..9ead54f2 100644 --- a/api/applications.go +++ b/api/applications.go @@ -17,7 +17,7 @@ import ( ) type ApplicationsClient interface { - GetApplication(ctx context.Context, clientID string) (Application, error) + GetApplication(ctx context.Context, applicationObjectID string) (Application, error) CreateApplication(ctx context.Context, displayName string, signInAudience string, tags []string) (Application, error) DeleteApplication(ctx context.Context, applicationObjectID string, permanentlyDelete bool) error ListApplications(ctx context.Context, filter string) ([]Application, error) @@ -73,30 +73,21 @@ func NewMSGraphClient(graphURI string, creds azcore.TokenCredential) (*MSGraphCl return ac, nil } -func (c *MSGraphClient) GetApplication(ctx context.Context, clientID string) (Application, error) { - filter := fmt.Sprintf("appId eq '%s'", clientID) - req := applications.ApplicationsRequestBuilderGetRequestConfiguration{ - QueryParameters: &applications.ApplicationsRequestBuilderGetQueryParameters{ - Filter: &filter, - }, - } - - resp, err := c.client.Applications().Get(ctx, &req) +func (c *MSGraphClient) GetApplication(ctx context.Context, applicationObjectID string) (Application, error) { + app, err := c.client.Applications().ByApplicationId(applicationObjectID).Get(ctx, nil) if err != nil { return Application{}, err } - apps := resp.GetValue() - if len(apps) == 0 { + if app == nil { return Application{}, fmt.Errorf("no application found") } - if len(apps) > 1 { - return Application{}, fmt.Errorf("multiple applications found - double check your client_id") - } - app := apps[0] - - return getApplicationResponse(app), nil + return Application{ + AppID: *app.GetAppId(), + AppObjectID: *app.GetId(), + PasswordCredentials: getPasswordCredentialsForApplication(app), + }, nil } func (c *MSGraphClient) ListApplications(ctx context.Context, filter string) ([]Application, error) { diff --git a/path_roles_test.go b/path_roles_test.go index 3cb89358..42ca7b43 100644 --- a/path_roles_test.go +++ b/path_roles_test.go @@ -6,11 +6,14 @@ package azuresecrets import ( "context" "fmt" + "os" "sort" "strings" "testing" "time" + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/sdk/helper/logging" "github.com/hashicorp/vault/sdk/logical" ) @@ -557,6 +560,89 @@ func TestRoleCreateBad(t *testing.T) { } } +// TestRolesCreate_applicationObjectID tests that a role +// can be created using the Application Object ID field. +// This test requires valid, sufficiently-privileged +// Azure credentials in env variables. +func TestRolesCreate_applicationObjectID(t *testing.T) { + if os.Getenv("VAULT_ACC") != "1" { + t.SkipNow() + } + + if os.Getenv("AZURE_CLIENT_SECRET") == "" { + t.Skip("Azure Secrets: Azure environment variables not set. Skipping.") + } + + t.Run("service principals", func(t *testing.T) { + t.Parallel() + + skipIfMissingEnvVars(t, + "AZURE_SUBSCRIPTION_ID", + "AZURE_CLIENT_ID", + "AZURE_CLIENT_SECRET", + "AZURE_TENANT_ID", + "AZURE_TEST_RESOURCE_GROUP", + ) + + b := backend() + s := new(logical.InmemStorage) + subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") + clientID := os.Getenv("AZURE_CLIENT_ID") + clientSecret := os.Getenv("AZURE_CLIENT_SECRET") + tenantID := os.Getenv("AZURE_TENANT_ID") + appObjectID := os.Getenv("AZURE_APPLICATION_OBJECT_ID") + + config := &logical.BackendConfig{ + Logger: logging.NewVaultLogger(log.Trace), + System: &logical.StaticSystemView{ + DefaultLeaseTTLVal: defaultLeaseTTLHr, + MaxLeaseTTLVal: maxLeaseTTLHr, + }, + StorageView: s, + } + err := b.Setup(context.Background(), config) + assertErrorIsNil(t, err) + + configData := map[string]interface{}{ + "application_object_id": subscriptionID, + "client_id": clientID, + "client_secret": clientSecret, + "tenant_id": tenantID, + } + + configResp, err := b.HandleRequest(context.Background(), &logical.Request{ + Operation: logical.CreateOperation, + Path: "config", + Data: configData, + Storage: s, + }) + assertRespNoError(t, configResp, err) + + roleName := "test_role_object_id" + + roleData := map[string]interface{}{ + "application_object_id": appObjectID, + "ttl": "1h", + } + + roleResp, err := b.HandleRequest(context.Background(), &logical.Request{ + Operation: logical.CreateOperation, + Path: fmt.Sprintf("roles/%s", roleName), + Data: roleData, + Storage: s, + }) + assertRespNoError(t, roleResp, err) + + credsResp, err := b.HandleRequest(context.Background(), &logical.Request{ + Operation: logical.ReadOperation, + Path: fmt.Sprintf("creds/%s", roleName), + Storage: s, + }) + assertRespNoError(t, credsResp, err) + + }) +} + func TestValidateTags(t *testing.T) { tests := []struct { name string diff --git a/provider_mock_test.go b/provider_mock_test.go index 316d7045..46877d0a 100644 --- a/provider_mock_test.go +++ b/provider_mock_test.go @@ -135,15 +135,15 @@ func (m *mockProvider) CreateApplication(_ context.Context, _ string, _ string, }, nil } -func (m *mockProvider) GetApplication(_ context.Context, clientID string) (api.Application, error) { +func (m *mockProvider) GetApplication(_ context.Context, applicationObjectID string) (api.Application, error) { m.lock.Lock() defer m.lock.Unlock() - appID := m.applications[clientID] + appID := m.applications[applicationObjectID] return api.Application{ AppID: appID, - AppObjectID: clientID, + AppObjectID: applicationObjectID, }, nil }