Skip to content

Commit

Permalink
remove context since it is not needed
Browse files Browse the repository at this point in the history
  • Loading branch information
vinay-gopalan committed May 8, 2024
1 parent 0878933 commit 4977b54
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 68 deletions.
46 changes: 24 additions & 22 deletions azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ import (
"github.com/coreos/go-oidc"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault-plugin-auth-azure/client"
"github.com/hashicorp/vault/sdk/helper/pluginidentityutil"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/helper/useragent"
"github.com/hashicorp/vault/sdk/logical"
"golang.org/x/oauth2"

"github.com/hashicorp/vault-plugin-auth-azure/client"
)

// https://learn.microsoft.com/en-us/graph/sdks/national-clouds
Expand All @@ -45,12 +46,12 @@ const (

type provider interface {
TokenVerifier() client.TokenVerifier
ComputeClient(ctx context.Context, subscriptionID string) (client.ComputeClient, error)
VMSSClient(ctx context.Context, subscriptionID string) (client.VMSSClient, error)
MSIClient(ctx context.Context, subscriptionID string) (client.MSIClient, error)
MSGraphClient(ctx context.Context) (client.MSGraphClient, error)
ResourceClient(ctx context.Context, subscriptionID string) (client.ResourceClient, error)
ProvidersClient(ctx context.Context, subscriptionID string) (client.ProvidersClient, error)
ComputeClient(subscriptionID string) (client.ComputeClient, error)
VMSSClient(subscriptionID string) (client.VMSSClient, error)
MSIClient(subscriptionID string) (client.MSIClient, error)
MSGraphClient() (client.MSGraphClient, error)
ResourceClient(subscriptionID string) (client.ResourceClient, error)
ProvidersClient(subscriptionID string) (client.ProvidersClient, error)
}

type azureProvider struct {
Expand Down Expand Up @@ -150,8 +151,8 @@ func (p *azureProvider) TokenVerifier() client.TokenVerifier {
return p.oidcVerifier
}

func (p *azureProvider) MSGraphClient(ctx context.Context) (client.MSGraphClient, error) {
cred, err := p.getTokenCredential(ctx)
func (p *azureProvider) MSGraphClient() (client.MSGraphClient, error) {
cred, err := p.getTokenCredential()
if err != nil {
return nil, err
}
Expand All @@ -164,8 +165,8 @@ func (p *azureProvider) MSGraphClient(ctx context.Context) (client.MSGraphClient
return msGraphAppClient, nil
}

func (p *azureProvider) ComputeClient(ctx context.Context, subscriptionID string) (client.ComputeClient, error) {
cred, err := p.getTokenCredential(ctx)
func (p *azureProvider) ComputeClient(subscriptionID string) (client.ComputeClient, error) {
cred, err := p.getTokenCredential()
if err != nil {
return nil, err
}
Expand All @@ -179,8 +180,8 @@ func (p *azureProvider) ComputeClient(ctx context.Context, subscriptionID string
return client, nil
}

func (p *azureProvider) VMSSClient(ctx context.Context, subscriptionID string) (client.VMSSClient, error) {
cred, err := p.getTokenCredential(ctx)
func (p *azureProvider) VMSSClient(subscriptionID string) (client.VMSSClient, error) {
cred, err := p.getTokenCredential()
if err != nil {
return nil, err
}
Expand All @@ -194,8 +195,8 @@ func (p *azureProvider) VMSSClient(ctx context.Context, subscriptionID string) (
return client, nil
}

func (p *azureProvider) MSIClient(ctx context.Context, subscriptionID string) (client.MSIClient, error) {
cred, err := p.getTokenCredential(ctx)
func (p *azureProvider) MSIClient(subscriptionID string) (client.MSIClient, error) {
cred, err := p.getTokenCredential()
if err != nil {
return nil, err
}
Expand All @@ -209,8 +210,8 @@ func (p *azureProvider) MSIClient(ctx context.Context, subscriptionID string) (c
return client, nil
}

func (p *azureProvider) ProvidersClient(ctx context.Context, subscriptionID string) (client.ProvidersClient, error) {
cred, err := p.getTokenCredential(ctx)
func (p *azureProvider) ProvidersClient(subscriptionID string) (client.ProvidersClient, error) {
cred, err := p.getTokenCredential()
if err != nil {
return nil, err
}
Expand All @@ -224,8 +225,8 @@ func (p *azureProvider) ProvidersClient(ctx context.Context, subscriptionID stri
return client, nil
}

func (p *azureProvider) ResourceClient(ctx context.Context, subscriptionID string) (client.ResourceClient, error) {
cred, err := p.getTokenCredential(ctx)
func (p *azureProvider) ResourceClient(subscriptionID string) (client.ResourceClient, error) {
cred, err := p.getTokenCredential()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -256,7 +257,7 @@ func (p *azureProvider) getClientOptions() *arm.ClientOptions {
}
}

func (p *azureProvider) getTokenCredential(ctx context.Context) (azcore.TokenCredential, error) {
func (p *azureProvider) getTokenCredential() (azcore.TokenCredential, error) {
clientCloudOpts := azcore.ClientOptions{Cloud: p.settings.CloudConfig}

if p.settings.ClientSecret != "" {
Expand All @@ -277,7 +278,7 @@ func (p *azureProvider) getTokenCredential(ctx context.Context) (azcore.TokenCre
options := &azidentity.ClientAssertionCredentialOptions{
ClientOptions: clientCloudOpts,
}
getAssertion := getAssertionFunc(ctx, p.logger, p.systemView, p.settings)
getAssertion := getAssertionFunc(p.logger, p.systemView, p.settings)
cred, err := azidentity.NewClientAssertionCredential(
p.settings.TenantID,
p.settings.ClientID,
Expand Down Expand Up @@ -305,7 +306,7 @@ func (p *azureProvider) getTokenCredential(ctx context.Context) (azcore.TokenCre

type getAssertion func(context.Context) (string, error)

func getAssertionFunc(ctx context.Context, logger hclog.Logger, sys logical.SystemView, s *azureSettings) getAssertion {
func getAssertionFunc(logger hclog.Logger, sys logical.SystemView, s *azureSettings) getAssertion {
return func(ctx context.Context) (string, error) {
req := &pluginutil.IdentityTokenRequest{
Audience: s.IdentityTokenAudience,
Expand Down Expand Up @@ -381,6 +382,7 @@ func (b *azureAuthBackend) getAzureSettings(ctx context.Context, config *azureCo
settings.ClientSecret = clientSecret

settings.IdentityTokenAudience = config.IdentityTokenAudience
settings.IdentityTokenTTL = config.IdentityTokenTTL

environment := os.Getenv("AZURE_ENVIRONMENT")
if environment == "" {
Expand Down
35 changes: 18 additions & 17 deletions azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/coreos/go-oidc"

"github.com/hashicorp/vault-plugin-auth-azure/client"
)

Expand Down Expand Up @@ -67,21 +68,21 @@ type mockProvidersClient struct {

func (c *mockComputeClient) Get(ctx context.Context, _, vmName string, _ *armcompute.VirtualMachinesClientGetOptions) (armcompute.VirtualMachinesClientGetResponse, error) {
if c.computeClientFunc != nil {
return c.computeClientFunc(ctx, vmName)
return c.computeClientFunc(vmName)
}
return armcompute.VirtualMachinesClientGetResponse{}, nil
}

func (c *mockVMSSClient) Get(ctx context.Context, _, vmssName string, _ *armcompute.VirtualMachineScaleSetsClientGetOptions) (armcompute.VirtualMachineScaleSetsClientGetResponse, error) {
if c.vmssClientFunc != nil {
return c.vmssClientFunc(ctx, vmssName)
return c.vmssClientFunc(vmssName)
}
return armcompute.VirtualMachineScaleSetsClientGetResponse{}, nil
}

func (c *mockMSIClient) Get(ctx context.Context, _, resourceName string, _ *armmsi.UserAssignedIdentitiesClientGetOptions) (armmsi.UserAssignedIdentitiesClientGetResponse, error) {
if c.msiClientFunc != nil {
return c.msiClientFunc(ctx, resourceName)
return c.msiClientFunc(resourceName)
}
return armmsi.UserAssignedIdentitiesClientGetResponse{}, nil
}
Expand All @@ -103,31 +104,31 @@ func (c *mockMSIClient) NewListByResourceGroupPager(resourceGroup string, _ *arm

func (c *mockResourceClient) GetByID(ctx context.Context, resourceID, _ string, _ *armresources.ClientGetByIDOptions) (armresources.ClientGetByIDResponse, error) {
if c.resourceClientFunc != nil {
return c.resourceClientFunc(ctx, resourceID)
return c.resourceClientFunc(resourceID)
}
return armresources.ClientGetByIDResponse{}, nil
}

func (c *mockProvidersClient) Get(ctx context.Context, resourceID string, _ *armresources.ProvidersClientGetOptions) (armresources.ProvidersClientGetResponse, error) {
if c.providersClientFunc != nil {
return c.providersClientFunc(ctx, resourceID)
return c.providersClientFunc(resourceID)
}
return armresources.ProvidersClientGetResponse{}, nil
}

type computeClientFunc func(ctx context.Context, vmName string) (armcompute.VirtualMachinesClientGetResponse, error)
type computeClientFunc func(vmName string) (armcompute.VirtualMachinesClientGetResponse, error)

type vmssClientFunc func(ctx context.Context, vmssName string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error)
type vmssClientFunc func(vmssName string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error)

type msiClientFunc func(ctx context.Context, resourceName string) (armmsi.UserAssignedIdentitiesClientGetResponse, error)
type msiClientFunc func(resourceName string) (armmsi.UserAssignedIdentitiesClientGetResponse, error)

type msiListFunc func(resoucename string) armmsi.UserAssignedIdentitiesClientListByResourceGroupResponse

type msGraphClientFunc func(ctx context.Context) (client.MSGraphClient, error)
type msGraphClientFunc func() (client.MSGraphClient, error)

type resourceClientFunc func(ctx context.Context, resourceID string) (armresources.ClientGetByIDResponse, error)
type resourceClientFunc func(resourceID string) (armresources.ClientGetByIDResponse, error)

type providersClientFunc func(ctx context.Context, s string) (armresources.ProvidersClientGetResponse, error)
type providersClientFunc func(s string) (armresources.ProvidersClientGetResponse, error)

type mockProvider struct {
computeClientFunc
Expand All @@ -153,36 +154,36 @@ func (*mockProvider) TokenVerifier() client.TokenVerifier {
return newMockVerifier()
}

func (p *mockProvider) ComputeClient(ctx context.Context, s string) (client.ComputeClient, error) {
func (p *mockProvider) ComputeClient(subscriptionID string) (client.ComputeClient, error) {
return &mockComputeClient{
computeClientFunc: p.computeClientFunc,
}, nil
}

func (p *mockProvider) VMSSClient(ctx context.Context, s string) (client.VMSSClient, error) {
func (p *mockProvider) VMSSClient(subscriptionID string) (client.VMSSClient, error) {
return &mockVMSSClient{
vmssClientFunc: p.vmssClientFunc,
}, nil
}

func (p *mockProvider) MSIClient(ctx context.Context, s string) (client.MSIClient, error) {
func (p *mockProvider) MSIClient(subscriptionID string) (client.MSIClient, error) {
return &mockMSIClient{
msiClientFunc: p.msiClientFunc,
msiListFunc: p.msiListFunc,
}, nil
}

func (p *mockProvider) MSGraphClient(ctx context.Context) (client.MSGraphClient, error) {
func (p *mockProvider) MSGraphClient() (client.MSGraphClient, error) {
return nil, nil
}

func (p *mockProvider) ResourceClient(ctx context.Context, s string) (client.ResourceClient, error) {
func (p *mockProvider) ResourceClient(subscriptionID string) (client.ResourceClient, error) {
return &mockResourceClient{
resourceClientFunc: p.resourceClientFunc,
}, nil
}

func (p *mockProvider) ProvidersClient(ctx context.Context, s string) (client.ProvidersClient, error) {
func (p *mockProvider) ProvidersClient(subscriptionID string) (client.ProvidersClient, error) {
return &mockProvidersClient{
providersClientFunc: p.providersClientFunc,
}, nil
Expand Down
2 changes: 1 addition & 1 deletion backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (b *azureAuthBackend) periodicFunc(ctx context.Context, req *logical.Reques
return err
}

client, err := provider.MSGraphClient(ctx)
client, err := provider.MSGraphClient()
if err != nil {
return err
}
Expand Down
12 changes: 6 additions & 6 deletions path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ func (b *azureAuthBackend) verifyResource(ctx context.Context, subscriptionID, r
// If vmss name is specified, the vm name will be ignored and only the scale set
// will be verified since vm names are generated automatically for scale sets
case vmssName != "":
client, err := b.provider.VMSSClient(ctx, subscriptionID)
client, err := b.provider.VMSSClient(subscriptionID)
if err != nil {
return err
}
Expand Down Expand Up @@ -320,7 +320,7 @@ func (b *azureAuthBackend) verifyResource(ctx context.Context, subscriptionID, r

// Principal ID is nil for VMSS flex orchestration mode, so we
// must look up the user-assigned identity using the MSI client
msiClient, err := b.provider.MSIClient(ctx, msiID.SubscriptionID)
msiClient, err := b.provider.MSIClient(msiID.SubscriptionID)
if err != nil {
return fmt.Errorf("failed to create client to retrieve user-assigned identity: %w", err)
}
Expand All @@ -334,7 +334,7 @@ func (b *azureAuthBackend) verifyResource(ctx context.Context, subscriptionID, r
}
}
case vmName != "":
client, err := b.provider.ComputeClient(ctx, subscriptionID)
client, err := b.provider.ComputeClient(subscriptionID)
if err != nil {
return err
}
Expand Down Expand Up @@ -378,7 +378,7 @@ func (b *azureAuthBackend) verifyResource(ctx context.Context, subscriptionID, r
return err
}

client, err := b.provider.ResourceClient(ctx, subscriptionID)
client, err := b.provider.ResourceClient(subscriptionID)
if err != nil {
return err
}
Expand Down Expand Up @@ -420,7 +420,7 @@ func (b *azureAuthBackend) verifyResource(ctx context.Context, subscriptionID, r
}

clientIDs := map[string]struct{}{}
c, err := b.provider.MSIClient(ctx, subscriptionID)
c, err := b.provider.MSIClient(subscriptionID)
if err != nil {
return fmt.Errorf("failed to create client to retrieve app ids: %w", err)
}
Expand Down Expand Up @@ -536,7 +536,7 @@ func (b *azureAuthBackend) getAPIVersionForResource(ctx context.Context, subscri
}
b.cacheLock.RUnlock()

client, err := b.provider.ProvidersClient(ctx, subscriptionID)
client, err := b.provider.ProvidersClient(subscriptionID)
if err != nil {
return "", err
}
Expand Down
Loading

0 comments on commit 4977b54

Please sign in to comment.