Skip to content

Commit

Permalink
Correctly use GetApplication with ApplicationObjectID instead of in…
Browse files Browse the repository at this point in the history
…correct filter (#200)

* fix getApplications method to properly fetch by ID instead of incorrect filter

* add integration test for role creation using object ID

* update signature and test to use applicationObjectID instead of clientID

* Update path_roles_test.go

Co-authored-by: Milena Zlaticanin <[email protected]>

---------

Co-authored-by: Ellie Sterner <[email protected]>
Co-authored-by: Milena Zlaticanin <[email protected]>
  • Loading branch information
3 people authored May 7, 2024
1 parent 560399b commit e061a81
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 21 deletions.
27 changes: 9 additions & 18 deletions api/applications.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
)

type ApplicationsClient interface {
GetApplication(ctx context.Context, clientID string) (Application, error)
GetApplication(ctx context.Context, applicationObjectID string) (Application, error)
CreateApplication(ctx context.Context, displayName string, signInAudience string, tags []string) (Application, error)
DeleteApplication(ctx context.Context, applicationObjectID string, permanentlyDelete bool) error
ListApplications(ctx context.Context, filter string) ([]Application, error)
Expand Down Expand Up @@ -73,30 +73,21 @@ func NewMSGraphClient(graphURI string, creds azcore.TokenCredential) (*MSGraphCl
return ac, nil
}

func (c *MSGraphClient) GetApplication(ctx context.Context, clientID string) (Application, error) {
filter := fmt.Sprintf("appId eq '%s'", clientID)
req := applications.ApplicationsRequestBuilderGetRequestConfiguration{
QueryParameters: &applications.ApplicationsRequestBuilderGetQueryParameters{
Filter: &filter,
},
}

resp, err := c.client.Applications().Get(ctx, &req)
func (c *MSGraphClient) GetApplication(ctx context.Context, applicationObjectID string) (Application, error) {
app, err := c.client.Applications().ByApplicationId(applicationObjectID).Get(ctx, nil)
if err != nil {
return Application{}, err
}

apps := resp.GetValue()
if len(apps) == 0 {
if app == nil {
return Application{}, fmt.Errorf("no application found")
}
if len(apps) > 1 {
return Application{}, fmt.Errorf("multiple applications found - double check your client_id")
}

app := apps[0]

return getApplicationResponse(app), nil
return Application{
AppID: *app.GetAppId(),
AppObjectID: *app.GetId(),
PasswordCredentials: getPasswordCredentialsForApplication(app),
}, nil
}

func (c *MSGraphClient) ListApplications(ctx context.Context, filter string) ([]Application, error) {
Expand Down
86 changes: 86 additions & 0 deletions path_roles_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@ package azuresecrets
import (
"context"
"fmt"
"os"
"sort"
"strings"
"testing"
"time"

log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
)

Expand Down Expand Up @@ -557,6 +560,89 @@ func TestRoleCreateBad(t *testing.T) {
}
}

// TestRolesCreate_applicationObjectID tests that a role
// can be created using the Application Object ID field.
// This test requires valid, sufficiently-privileged
// Azure credentials in env variables.
func TestRolesCreate_applicationObjectID(t *testing.T) {
if os.Getenv("VAULT_ACC") != "1" {
t.SkipNow()
}

if os.Getenv("AZURE_CLIENT_SECRET") == "" {
t.Skip("Azure Secrets: Azure environment variables not set. Skipping.")
}

t.Run("service principals", func(t *testing.T) {
t.Parallel()

skipIfMissingEnvVars(t,
"AZURE_SUBSCRIPTION_ID",
"AZURE_CLIENT_ID",
"AZURE_CLIENT_SECRET",
"AZURE_TENANT_ID",
"AZURE_TEST_RESOURCE_GROUP",
)

b := backend()
s := new(logical.InmemStorage)
subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID")
clientID := os.Getenv("AZURE_CLIENT_ID")
clientSecret := os.Getenv("AZURE_CLIENT_SECRET")
tenantID := os.Getenv("AZURE_TENANT_ID")
appObjectID := os.Getenv("AZURE_APPLICATION_OBJECT_ID")

config := &logical.BackendConfig{
Logger: logging.NewVaultLogger(log.Trace),
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLHr,
MaxLeaseTTLVal: maxLeaseTTLHr,
},
StorageView: s,
}
err := b.Setup(context.Background(), config)
assertErrorIsNil(t, err)

configData := map[string]interface{}{
"application_object_id": subscriptionID,
"client_id": clientID,
"client_secret": clientSecret,
"tenant_id": tenantID,
}

configResp, err := b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.CreateOperation,
Path: "config",
Data: configData,
Storage: s,
})
assertRespNoError(t, configResp, err)

roleName := "test_role_object_id"

roleData := map[string]interface{}{
"application_object_id": appObjectID,
"ttl": "1h",
}

roleResp, err := b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.CreateOperation,
Path: fmt.Sprintf("roles/%s", roleName),
Data: roleData,
Storage: s,
})
assertRespNoError(t, roleResp, err)

credsResp, err := b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.ReadOperation,
Path: fmt.Sprintf("creds/%s", roleName),
Storage: s,
})
assertRespNoError(t, credsResp, err)

})
}

func TestValidateTags(t *testing.T) {
tests := []struct {
name string
Expand Down
6 changes: 3 additions & 3 deletions provider_mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,15 @@ func (m *mockProvider) CreateApplication(_ context.Context, _ string, _ string,
}, nil
}

func (m *mockProvider) GetApplication(_ context.Context, clientID string) (api.Application, error) {
func (m *mockProvider) GetApplication(_ context.Context, applicationObjectID string) (api.Application, error) {
m.lock.Lock()
defer m.lock.Unlock()

appID := m.applications[clientID]
appID := m.applications[applicationObjectID]

return api.Application{
AppID: appID,
AppObjectID: clientID,
AppObjectID: applicationObjectID,
}, nil
}

Expand Down

0 comments on commit e061a81

Please sign in to comment.