diff --git a/azure/scope/cluster_test.go b/azure/scope/cluster_test.go index de929b7b459..ad4fdaf0806 100644 --- a/azure/scope/cluster_test.go +++ b/azure/scope/cluster_test.go @@ -42,9 +42,20 @@ import ( func specToString(spec azure.ResourceSpecGetter) string { var sb strings.Builder - sb.WriteString("[ ") + sb.WriteString("{ ") sb.WriteString(fmt.Sprintf("%+v ", spec)) + sb.WriteString("}") + return sb.String() +} + +func specArrayToString(specs []azure.ResourceSpecGetter) string { + var sb strings.Builder + sb.WriteString("[\n") + for _, spec := range specs { + sb.WriteString(fmt.Sprintf("\t%+v\n", specToString(spec))) + } sb.WriteString("]") + return sb.String() } diff --git a/azure/scope/machine.go b/azure/scope/machine.go index 318cafe4e49..2612f893e1c 100644 --- a/azure/scope/machine.go +++ b/azure/scope/machine.go @@ -37,6 +37,7 @@ import ( "sigs.k8s.io/cluster-api-provider-azure/azure/services/resourceskus" "sigs.k8s.io/cluster-api-provider-azure/azure/services/roleassignments" "sigs.k8s.io/cluster-api-provider-azure/azure/services/virtualmachines" + "sigs.k8s.io/cluster-api-provider-azure/azure/services/vmextensions" "sigs.k8s.io/cluster-api-provider-azure/util/futures" "sigs.k8s.io/cluster-api-provider-azure/util/tele" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" @@ -317,13 +318,17 @@ func (m *MachineScope) HasSystemAssignedIdentity() bool { return m.AzureMachine.Spec.Identity == infrav1.VMIdentitySystemAssigned } -// VMExtensionSpecs returns the vm extension specs. -func (m *MachineScope) VMExtensionSpecs() []azure.ExtensionSpec { - var extensionSpecs = []azure.ExtensionSpec{} - extensionSpec := azure.GetBootstrappingVMExtension(m.AzureMachine.Spec.OSDisk.OSType, m.CloudEnvironment(), m.Name()) +// VMExtensionSpecs returns the VM extension specs. +func (m *MachineScope) VMExtensionSpecs() []azure.ResourceSpecGetter { + var extensionSpecs = []azure.ResourceSpecGetter{} + bootstrapExtensionSpec := azure.GetBootstrappingVMExtension(m.AzureMachine.Spec.OSDisk.OSType, m.CloudEnvironment(), m.Name()) - if extensionSpec != nil { - extensionSpecs = append(extensionSpecs, *extensionSpec) + if bootstrapExtensionSpec != nil { + extensionSpecs = append(extensionSpecs, &vmextensions.VMExtensionSpec{ + ExtensionSpec: *bootstrapExtensionSpec, + ResourceGroup: m.ResourceGroup(), + Location: m.Location(), + }) } return extensionSpecs diff --git a/azure/scope/machine_test.go b/azure/scope/machine_test.go index 2126ad5a22c..70e77ffdfbb 100644 --- a/azure/scope/machine_test.go +++ b/azure/scope/machine_test.go @@ -18,9 +18,7 @@ package scope import ( "context" - "fmt" "reflect" - "strings" "testing" autorestazure "github.com/Azure/go-autorest/autorest/azure" @@ -36,20 +34,10 @@ import ( "sigs.k8s.io/cluster-api-provider-azure/azure/services/networkinterfaces" "sigs.k8s.io/cluster-api-provider-azure/azure/services/resourceskus" "sigs.k8s.io/cluster-api-provider-azure/azure/services/roleassignments" + "sigs.k8s.io/cluster-api-provider-azure/azure/services/vmextensions" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" ) -func specArrayToString(specs []azure.ResourceSpecGetter) string { - var sb strings.Builder - sb.WriteString("[ ") - for _, spec := range specs { - sb.WriteString(fmt.Sprintf("%+v ", spec)) - } - sb.WriteString("]") - - return sb.String() -} - func TestMachineScope_Name(t *testing.T) { tests := []struct { name string @@ -420,6 +408,9 @@ func TestMachineScope_RoleAssignmentSpecs(t *testing.T) { AzureCluster: &infrav1.AzureCluster{ Spec: infrav1.AzureClusterSpec{ ResourceGroup: "my-rg", + AzureClusterClassSpec: infrav1.AzureClusterClassSpec{ + Location: "westus", + }, }, }, }, @@ -450,7 +441,7 @@ func TestMachineScope_VMExtensionSpecs(t *testing.T) { tests := []struct { name string machineScope MachineScope - want []azure.ExtensionSpec + want []azure.ResourceSpecGetter }{ { name: "If OS type is Linux and cloud is AzurePublicCloud, it returns ExtensionSpec", @@ -474,17 +465,29 @@ func TestMachineScope_VMExtensionSpecs(t *testing.T) { }, }, }, + AzureCluster: &infrav1.AzureCluster{ + Spec: infrav1.AzureClusterSpec{ + ResourceGroup: "my-rg", + AzureClusterClassSpec: infrav1.AzureClusterClassSpec{ + Location: "westus", + }, + }, + }, }, }, - want: []azure.ExtensionSpec{ - { - Name: "CAPZ.Linux.Bootstrapping", - VMName: "machine-name", - Publisher: "Microsoft.Azure.ContainerUpstream", - Version: "1.0", - ProtectedSettings: map[string]string{ - "commandToExecute": azure.LinuxBootstrapExtensionCommand, + want: []azure.ResourceSpecGetter{ + &vmextensions.VMExtensionSpec{ + ExtensionSpec: azure.ExtensionSpec{ + Name: "CAPZ.Linux.Bootstrapping", + VMName: "machine-name", + Publisher: "Microsoft.Azure.ContainerUpstream", + Version: "1.0", + ProtectedSettings: map[string]string{ + "commandToExecute": azure.LinuxBootstrapExtensionCommand, + }, }, + ResourceGroup: "my-rg", + Location: "westus", }, }, }, @@ -510,9 +513,17 @@ func TestMachineScope_VMExtensionSpecs(t *testing.T) { }, }, }, + AzureCluster: &infrav1.AzureCluster{ + Spec: infrav1.AzureClusterSpec{ + ResourceGroup: "my-rg", + AzureClusterClassSpec: infrav1.AzureClusterClassSpec{ + Location: "westus", + }, + }, + }, }, }, - want: []azure.ExtensionSpec{}, + want: []azure.ResourceSpecGetter{}, }, { name: "If OS type is Windows and cloud is AzurePublicCloud, it returns ExtensionSpec", @@ -536,17 +547,29 @@ func TestMachineScope_VMExtensionSpecs(t *testing.T) { }, }, }, + AzureCluster: &infrav1.AzureCluster{ + Spec: infrav1.AzureClusterSpec{ + ResourceGroup: "my-rg", + AzureClusterClassSpec: infrav1.AzureClusterClassSpec{ + Location: "westus", + }, + }, + }, }, }, - want: []azure.ExtensionSpec{ - { - Name: "CAPZ.Windows.Bootstrapping", - VMName: "machine-name", - Publisher: "Microsoft.Azure.ContainerUpstream", - Version: "1.0", - ProtectedSettings: map[string]string{ - "commandToExecute": azure.WindowsBootstrapExtensionCommand, + want: []azure.ResourceSpecGetter{ + &vmextensions.VMExtensionSpec{ + ExtensionSpec: azure.ExtensionSpec{ + Name: "CAPZ.Windows.Bootstrapping", + VMName: "machine-name", + Publisher: "Microsoft.Azure.ContainerUpstream", + Version: "1.0", + ProtectedSettings: map[string]string{ + "commandToExecute": azure.WindowsBootstrapExtensionCommand, + }, }, + ResourceGroup: "my-rg", + Location: "westus", }, }, }, @@ -572,9 +595,17 @@ func TestMachineScope_VMExtensionSpecs(t *testing.T) { }, }, }, + AzureCluster: &infrav1.AzureCluster{ + Spec: infrav1.AzureClusterSpec{ + ResourceGroup: "my-rg", + AzureClusterClassSpec: infrav1.AzureClusterClassSpec{ + Location: "westus", + }, + }, + }, }, }, - want: []azure.ExtensionSpec{}, + want: []azure.ResourceSpecGetter{}, }, { name: "If OS type is not Linux or Windows and cloud is AzurePublicCloud, it returns empty", @@ -598,9 +629,17 @@ func TestMachineScope_VMExtensionSpecs(t *testing.T) { }, }, }, + AzureCluster: &infrav1.AzureCluster{ + Spec: infrav1.AzureClusterSpec{ + ResourceGroup: "my-rg", + AzureClusterClassSpec: infrav1.AzureClusterClassSpec{ + Location: "westus", + }, + }, + }, }, }, - want: []azure.ExtensionSpec{}, + want: []azure.ResourceSpecGetter{}, }, { name: "If OS type is not Windows or Linux and cloud is not AzurePublicCloud, it returns empty", @@ -624,15 +663,23 @@ func TestMachineScope_VMExtensionSpecs(t *testing.T) { }, }, }, + AzureCluster: &infrav1.AzureCluster{ + Spec: infrav1.AzureClusterSpec{ + ResourceGroup: "my-rg", + AzureClusterClassSpec: infrav1.AzureClusterClassSpec{ + Location: "westus", + }, + }, + }, }, }, - want: []azure.ExtensionSpec{}, + want: []azure.ResourceSpecGetter{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := tt.machineScope.VMExtensionSpecs(); !reflect.DeepEqual(got, tt.want) { - t.Errorf("VMExtensionSpecs() = %v, want %v", got, tt.want) + t.Errorf("VMExtensionSpecs() = \n%s, want \n%s", specArrayToString(got), specArrayToString(tt.want)) } }) } diff --git a/azure/scope/machinepool.go b/azure/scope/machinepool.go index 145dae64f1b..34cc4ae244c 100644 --- a/azure/scope/machinepool.go +++ b/azure/scope/machinepool.go @@ -32,6 +32,7 @@ import ( "sigs.k8s.io/cluster-api-provider-azure/azure" machinepool "sigs.k8s.io/cluster-api-provider-azure/azure/scope/strategies/machinepool_deployments" "sigs.k8s.io/cluster-api-provider-azure/azure/services/roleassignments" + "sigs.k8s.io/cluster-api-provider-azure/azure/services/vmssextensions" infrav1exp "sigs.k8s.io/cluster-api-provider-azure/exp/api/v1beta1" "sigs.k8s.io/cluster-api-provider-azure/util/futures" "sigs.k8s.io/cluster-api-provider-azure/util/tele" @@ -597,12 +598,15 @@ func (m *MachinePoolScope) HasSystemAssignedIdentity() bool { } // VMSSExtensionSpecs returns the vmss extension specs. -func (m *MachinePoolScope) VMSSExtensionSpecs() []azure.ExtensionSpec { - var extensionSpecs = []azure.ExtensionSpec{} - extensionSpec := azure.GetBootstrappingVMExtension(m.AzureMachinePool.Spec.Template.OSDisk.OSType, m.CloudEnvironment(), m.Name()) +func (m *MachinePoolScope) VMSSExtensionSpecs() []azure.ResourceSpecGetter { + var extensionSpecs = []azure.ResourceSpecGetter{} + bootstrapExtensionSpec := azure.GetBootstrappingVMExtension(m.AzureMachinePool.Spec.Template.OSDisk.OSType, m.CloudEnvironment(), m.Name()) - if extensionSpec != nil { - extensionSpecs = append(extensionSpecs, *extensionSpec) + if bootstrapExtensionSpec != nil { + extensionSpecs = append(extensionSpecs, &vmssextensions.VMSSExtensionSpec{ + ExtensionSpec: *bootstrapExtensionSpec, + ResourceGroup: m.ResourceGroup(), + }) } return extensionSpecs diff --git a/azure/scope/machinepool_test.go b/azure/scope/machinepool_test.go index 344008e13b5..a532e22af96 100644 --- a/azure/scope/machinepool_test.go +++ b/azure/scope/machinepool_test.go @@ -33,6 +33,7 @@ import ( "k8s.io/apimachinery/pkg/util/intstr" infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" "sigs.k8s.io/cluster-api-provider-azure/azure" + "sigs.k8s.io/cluster-api-provider-azure/azure/services/vmssextensions" infrav1exp "sigs.k8s.io/cluster-api-provider-azure/exp/api/v1beta1" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" clusterv1exp "sigs.k8s.io/cluster-api/exp/api/v1beta1" @@ -624,7 +625,7 @@ func TestMachinePoolScope_VMSSExtensionSpecs(t *testing.T) { tests := []struct { name string machinePoolScope MachinePoolScope - want []azure.ExtensionSpec + want []azure.ResourceSpecGetter }{ { name: "If OS type is Linux and cloud is AzurePublicCloud, it returns ExtensionSpec", @@ -650,17 +651,25 @@ func TestMachinePoolScope_VMSSExtensionSpecs(t *testing.T) { }, }, }, + AzureCluster: &infrav1.AzureCluster{ + Spec: infrav1.AzureClusterSpec{ + ResourceGroup: "my-rg", + }, + }, }, }, - want: []azure.ExtensionSpec{ - { - Name: "CAPZ.Linux.Bootstrapping", - VMName: "machinepool-name", - Publisher: "Microsoft.Azure.ContainerUpstream", - Version: "1.0", - ProtectedSettings: map[string]string{ - "commandToExecute": azure.LinuxBootstrapExtensionCommand, + want: []azure.ResourceSpecGetter{ + &vmssextensions.VMSSExtensionSpec{ + ExtensionSpec: azure.ExtensionSpec{ + Name: "CAPZ.Linux.Bootstrapping", + VMName: "machinepool-name", + Publisher: "Microsoft.Azure.ContainerUpstream", + Version: "1.0", + ProtectedSettings: map[string]string{ + "commandToExecute": azure.LinuxBootstrapExtensionCommand, + }, }, + ResourceGroup: "my-rg", }, }, }, @@ -688,9 +697,14 @@ func TestMachinePoolScope_VMSSExtensionSpecs(t *testing.T) { }, }, }, + AzureCluster: &infrav1.AzureCluster{ + Spec: infrav1.AzureClusterSpec{ + ResourceGroup: "my-rg", + }, + }, }, }, - want: []azure.ExtensionSpec{}, + want: []azure.ResourceSpecGetter{}, }, { name: "If OS type is Windows and cloud is AzurePublicCloud, it returns ExtensionSpec", @@ -717,18 +731,26 @@ func TestMachinePoolScope_VMSSExtensionSpecs(t *testing.T) { }, }, }, + AzureCluster: &infrav1.AzureCluster{ + Spec: infrav1.AzureClusterSpec{ + ResourceGroup: "my-rg", + }, + }, }, }, - want: []azure.ExtensionSpec{ - { - Name: "CAPZ.Windows.Bootstrapping", - // Note: machine pool names longer than 9 characters get truncated. See MachinePoolScope::Name() for more details. - VMName: "winpool", - Publisher: "Microsoft.Azure.ContainerUpstream", - Version: "1.0", - ProtectedSettings: map[string]string{ - "commandToExecute": azure.WindowsBootstrapExtensionCommand, + want: []azure.ResourceSpecGetter{ + &vmssextensions.VMSSExtensionSpec{ + ExtensionSpec: azure.ExtensionSpec{ + Name: "CAPZ.Windows.Bootstrapping", + // Note: machine pool names longer than 9 characters get truncated. See MachinePoolScope::Name() for more details. + VMName: "winpool", + Publisher: "Microsoft.Azure.ContainerUpstream", + Version: "1.0", + ProtectedSettings: map[string]string{ + "commandToExecute": azure.WindowsBootstrapExtensionCommand, + }, }, + ResourceGroup: "my-rg", }, }, }, @@ -756,9 +778,14 @@ func TestMachinePoolScope_VMSSExtensionSpecs(t *testing.T) { }, }, }, + AzureCluster: &infrav1.AzureCluster{ + Spec: infrav1.AzureClusterSpec{ + ResourceGroup: "my-rg", + }, + }, }, }, - want: []azure.ExtensionSpec{}, + want: []azure.ResourceSpecGetter{}, }, { name: "If OS type is not Linux or Windows and cloud is AzurePublicCloud, it returns empty", @@ -784,9 +811,14 @@ func TestMachinePoolScope_VMSSExtensionSpecs(t *testing.T) { }, }, }, + AzureCluster: &infrav1.AzureCluster{ + Spec: infrav1.AzureClusterSpec{ + ResourceGroup: "my-rg", + }, + }, }, }, - want: []azure.ExtensionSpec{}, + want: []azure.ResourceSpecGetter{}, }, { name: "If OS type is not Windows or Linux and cloud is not AzurePublicCloud, it returns empty", @@ -812,15 +844,20 @@ func TestMachinePoolScope_VMSSExtensionSpecs(t *testing.T) { }, }, }, + AzureCluster: &infrav1.AzureCluster{ + Spec: infrav1.AzureClusterSpec{ + ResourceGroup: "my-rg", + }, + }, }, }, - want: []azure.ExtensionSpec{}, + want: []azure.ResourceSpecGetter{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := tt.machinePoolScope.VMSSExtensionSpecs(); !reflect.DeepEqual(got, tt.want) { - t.Errorf("VMSSExtensionSpecs() = %v, want %v", got, tt.want) + t.Errorf("VMSSExtensionSpecs() = \n%s, want \n%s", specArrayToString(got), specArrayToString(tt.want)) } }) } diff --git a/azure/services/scalesets/mock_scalesets/scalesets_mock.go b/azure/services/scalesets/mock_scalesets/scalesets_mock.go index 623a616200f..97b00251d82 100644 --- a/azure/services/scalesets/mock_scalesets/scalesets_mock.go +++ b/azure/services/scalesets/mock_scalesets/scalesets_mock.go @@ -446,10 +446,10 @@ func (mr *MockScaleSetScopeMockRecorder) UpdatePutStatus(arg0, arg1, arg2 interf } // VMSSExtensionSpecs mocks base method. -func (m *MockScaleSetScope) VMSSExtensionSpecs() []azure.ExtensionSpec { +func (m *MockScaleSetScope) VMSSExtensionSpecs() []azure.ResourceSpecGetter { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "VMSSExtensionSpecs") - ret0, _ := ret[0].([]azure.ExtensionSpec) + ret0, _ := ret[0].([]azure.ResourceSpecGetter) return ret0 } diff --git a/azure/services/scalesets/scalesets.go b/azure/services/scalesets/scalesets.go index 55df787a2ec..b7279a78ac5 100644 --- a/azure/services/scalesets/scalesets.go +++ b/azure/services/scalesets/scalesets.go @@ -46,7 +46,7 @@ type ( SaveVMImageToStatus(*infrav1.Image) MaxSurge() (int, error) ScaleSetSpec() azure.ScaleSetSpec - VMSSExtensionSpecs() []azure.ExtensionSpec + VMSSExtensionSpecs() []azure.ResourceSpecGetter SetAnnotation(string, string) SetProviderID(string) SetVMSSState(*azure.VMSS) @@ -370,7 +370,10 @@ func (s *Service) buildVMSSFromSpec(ctx context.Context, vmssSpec azure.ScaleSet vmssSpec.AcceleratedNetworking = &accelNet } - extensions := s.generateExtensions() + extensions, err := s.generateExtensions() + if err != nil { + return compute.VirtualMachineScaleSet{}, err + } storageProfile, err := s.generateStorageProfile(ctx, vmssSpec, sku) if err != nil { @@ -543,22 +546,22 @@ func (s *Service) getVirtualMachineScaleSetIfDone(ctx context.Context, future *i return converters.SDKToVMSS(vmss, vmssInstances), nil } -func (s *Service) generateExtensions() []compute.VirtualMachineScaleSetExtension { +func (s *Service) generateExtensions() ([]compute.VirtualMachineScaleSetExtension, error) { extensions := make([]compute.VirtualMachineScaleSetExtension, len(s.Scope.VMSSExtensionSpecs())) for i, extensionSpec := range s.Scope.VMSSExtensionSpecs() { extensionSpec := extensionSpec - extensions[i] = compute.VirtualMachineScaleSetExtension{ - Name: &extensionSpec.Name, - VirtualMachineScaleSetExtensionProperties: &compute.VirtualMachineScaleSetExtensionProperties{ - Publisher: to.StringPtr(extensionSpec.Publisher), - Type: to.StringPtr(extensionSpec.Name), - TypeHandlerVersion: to.StringPtr(extensionSpec.Version), - Settings: nil, - ProtectedSettings: extensionSpec.ProtectedSettings, - }, + parameters, err := extensionSpec.Parameters(nil) + if err != nil { + return nil, err + } + vmssextension, ok := parameters.(compute.VirtualMachineScaleSetExtension) + if !ok { + return nil, errors.Errorf("%T is not a compute.VirtualMachineScaleSetExtension", parameters) } + extensions[i] = vmssextension } - return extensions + + return extensions, nil } // generateStorageProfile generates a pointer to a compute.VirtualMachineScaleSetStorageProfile which can utilized for VM creation. diff --git a/azure/services/scalesets/scalesets_test.go b/azure/services/scalesets/scalesets_test.go index db14ccc8663..17817c275d8 100644 --- a/azure/services/scalesets/scalesets_test.go +++ b/azure/services/scalesets/scalesets_test.go @@ -33,6 +33,7 @@ import ( "sigs.k8s.io/cluster-api-provider-azure/azure" "sigs.k8s.io/cluster-api-provider-azure/azure/services/resourceskus" "sigs.k8s.io/cluster-api-provider-azure/azure/services/scalesets/mock_scalesets" + "sigs.k8s.io/cluster-api-provider-azure/azure/services/vmssextensions" gomockinternal "sigs.k8s.io/cluster-api-provider-azure/internal/test/matchers/gomock" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" ) @@ -1218,15 +1219,18 @@ func setupVMSSExpectationsWithoutVMImage(s *mock_scalesets.MockScaleSetScopeMock s.Location().AnyTimes().Return("test-location") s.ClusterName().Return("my-cluster") s.GetBootstrapData(gomockinternal.AContext()).Return("fake-bootstrap-data", nil) - s.VMSSExtensionSpecs().Return([]azure.ExtensionSpec{ - { - Name: "someExtension", - VMName: "my-vmss", - Publisher: "somePublisher", - Version: "someVersion", - ProtectedSettings: map[string]string{ - "commandToExecute": "echo hello", + s.VMSSExtensionSpecs().Return([]azure.ResourceSpecGetter{ + &vmssextensions.VMSSExtensionSpec{ + ExtensionSpec: azure.ExtensionSpec{ + Name: "someExtension", + VMName: "my-vmss", + Publisher: "somePublisher", + Version: "someVersion", + ProtectedSettings: map[string]string{ + "commandToExecute": "echo hello", + }, }, + ResourceGroup: "my-rg", }, }).AnyTimes() } diff --git a/azure/services/vmextensions/client.go b/azure/services/vmextensions/client.go index e5a1f7a24da..5ab8f168ab1 100644 --- a/azure/services/vmextensions/client.go +++ b/azure/services/vmextensions/client.go @@ -18,27 +18,23 @@ package vmextensions import ( "context" + "encoding/json" "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2021-04-01/compute" "github.com/Azure/go-autorest/autorest" + azureautorest "github.com/Azure/go-autorest/autorest/azure" + "github.com/pkg/errors" + infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" "sigs.k8s.io/cluster-api-provider-azure/azure" + "sigs.k8s.io/cluster-api-provider-azure/util/reconciler" "sigs.k8s.io/cluster-api-provider-azure/util/tele" ) -// client wraps go-sdk. -type client interface { - Get(ctx context.Context, resourceGroupName, vmName, name string) (compute.VirtualMachineExtension, error) - CreateOrUpdateAsync(context.Context, string, string, string, compute.VirtualMachineExtension) error - Delete(context.Context, string, string, string) error -} - // azureClient contains the Azure go-sdk Client. type azureClient struct { vmextensions compute.VirtualMachineExtensionsClient } -var _ client = (*azureClient)(nil) - // newClient creates a new VM client from subscription ID. func newClient(auth azure.Authorizer) *azureClient { c := newVirtualMachineExtensionsClient(auth.SubscriptionID(), auth.BaseURI(), auth.Authorizer()) @@ -52,36 +48,118 @@ func newVirtualMachineExtensionsClient(subscriptionID string, baseURI string, au return vmextensionsClient } -// Get the virtual machine extension. -func (ac *azureClient) Get(ctx context.Context, resourceGroupName, vmName, name string) (compute.VirtualMachineExtension, error) { +// Get the specified virtual machine extension. +func (ac *azureClient) Get(ctx context.Context, spec azure.ResourceSpecGetter) (result interface{}, err error) { ctx, _, done := tele.StartSpanWithLogger(ctx, "vmextensions.AzureClient.Get") defer done() - return ac.vmextensions.Get(ctx, resourceGroupName, vmName, name, "") + return ac.vmextensions.Get(ctx, spec.ResourceGroupName(), spec.OwnerResourceName(), spec.ResourceName(), "") } -// CreateOrUpdateAsync creates or updates the virtual machine extension. -func (ac *azureClient) CreateOrUpdateAsync(ctx context.Context, resourceGroupName, vmName, name string, parameters compute.VirtualMachineExtension) error { +// CreateOrUpdateAsync creates or updates a VM extension asynchronously. +// It sends a PUT request to Azure and if accepted without error, the func will return a Future which can be used to track the ongoing +// progress of the operation. +func (ac *azureClient) CreateOrUpdateAsync(ctx context.Context, spec azure.ResourceSpecGetter, parameters interface{}) (result interface{}, future azureautorest.FutureAPI, err error) { ctx, _, done := tele.StartSpanWithLogger(ctx, "vmextensions.AzureClient.CreateOrUpdate") defer done() - _, err := ac.vmextensions.CreateOrUpdate(ctx, resourceGroupName, vmName, name, parameters) - return err + vmextension, ok := parameters.(compute.VirtualMachineExtension) + if !ok { + return nil, nil, errors.Errorf("%T is not a compute.VirtualMachineExtension", parameters) + } + + createFuture, err := ac.vmextensions.CreateOrUpdate(ctx, spec.ResourceGroupName(), spec.OwnerResourceName(), spec.ResourceName(), vmextension) + if err != nil { + return nil, nil, err + } + + ctx, cancel := context.WithTimeout(ctx, reconciler.DefaultAzureCallTimeout) + defer cancel() + + err = createFuture.WaitForCompletionRef(ctx, ac.vmextensions.Client) + if err != nil { + // if an error occurs, return the future. + // this means the long-running operation didn't finish in the specified timeout. + return nil, &createFuture, err + } + result, err = createFuture.Result(ac.vmextensions) + // if the operation completed, return a nil future. + return result, nil, err + // ctx, _, done := tele.StartSpanWithLogger(ctx, "vmextensions.AzureClient.CreateOrUpdate") + // defer done() + + // _, err := ac.vmextensions.CreateOrUpdate(ctx, resourceGroupName, vmName, name, parameters) + // return err } -// Delete removes the virtual machine extension. -func (ac *azureClient) Delete(ctx context.Context, resourceGroupName, vmName, name string) error { +// DeleteAsync deletes a VM extension asynchronously. DeleteAsync sends a DELETE +// request to Azure and if accepted without error, the func will return a Future which can be used to track the ongoing +// progress of the operation. +func (ac *azureClient) DeleteAsync(ctx context.Context, spec azure.ResourceSpecGetter) (future azureautorest.FutureAPI, err error) { ctx, _, done := tele.StartSpanWithLogger(ctx, "vmextensions.AzureClient.Delete") defer done() - future, err := ac.vmextensions.Delete(ctx, resourceGroupName, vmName, name) + deleteFuture, err := ac.vmextensions.Delete(ctx, spec.ResourceGroupName(), spec.OwnerResourceName(), spec.ResourceName()) if err != nil { - return err + return nil, err } - err = future.WaitForCompletionRef(ctx, ac.vmextensions.Client) + + ctx, cancel := context.WithTimeout(ctx, reconciler.DefaultAzureCallTimeout) + defer cancel() + + err = deleteFuture.WaitForCompletionRef(ctx, ac.vmextensions.Client) if err != nil { - return err + // if an error occurs, return the future. + // this means the long-running operation didn't finish in the specified timeout. + return &deleteFuture, err + } + _, err = deleteFuture.Result(ac.vmextensions) + // if the operation completed, return a nil future. + return nil, err +} + +// IsDone returns true if the long-running operation has completed. +func (ac *azureClient) IsDone(ctx context.Context, future azureautorest.FutureAPI) (isDone bool, err error) { + ctx, _, done := tele.StartSpanWithLogger(ctx, "virtualnetworks.azureClient.IsDone") + defer done() + + isDone, err = future.DoneWithContext(ctx, ac.vmextensions) + if err != nil { + return false, errors.Wrap(err, "failed checking if the operation was complete") + } + + return isDone, nil +} + +// Result fetches the result of a long-running operation future. +func (ac *azureClient) Result(ctx context.Context, future azureautorest.FutureAPI, futureType string) (result interface{}, err error) { + _, _, done := tele.StartSpanWithLogger(ctx, "vmextensions.azureClient.Result") + defer done() + + if future == nil { + return nil, errors.Errorf("cannot get result from nil future") + } + + switch futureType { + case infrav1.PutFuture: + // Marshal and Unmarshal the future to put it into the correct future type so we can access the Result function. + // Unfortunately the FutureAPI can't be casted directly to VirtualMachineExtensionsCreateOrUpdateFuture because it is a azureautorest.Future, which doesn't implement the Result function. See PR #1686 for discussion on alternatives. + // It was converted back to a generic azureautorest.Future from the CAPZ infrav1.Future type stored in Status: https://github.com/kubernetes-sigs/cluster-api-provider-azure/blob/main/azure/converters/futures.go#L49. + var createFuture *compute.VirtualMachineExtensionsCreateOrUpdateFuture + jsonData, err := future.MarshalJSON() + if err != nil { + return nil, errors.Wrap(err, "failed to marshal future") + } + if err := json.Unmarshal(jsonData, &createFuture); err != nil { + return nil, errors.Wrap(err, "failed to unmarshal future data") + } + return createFuture.Result(ac.vmextensions) + + case infrav1.DeleteFuture: + // Delete does not return a result vnet. + return nil, nil + + default: + return nil, errors.Errorf("unknown future type %q", futureType) } - _, err = future.Result(ac.vmextensions) - return err } diff --git a/azure/services/vmextensions/mock_vmextensions/client_mock.go b/azure/services/vmextensions/mock_vmextensions/client_mock.go deleted file mode 100644 index 1c428838660..00000000000 --- a/azure/services/vmextensions/mock_vmextensions/client_mock.go +++ /dev/null @@ -1,95 +0,0 @@ -/* -Copyright The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Code generated by MockGen. DO NOT EDIT. -// Source: ../client.go - -// Package mock_vmextensions is a generated GoMock package. -package mock_vmextensions - -import ( - context "context" - reflect "reflect" - - compute "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2021-04-01/compute" - gomock "github.com/golang/mock/gomock" -) - -// Mockclient is a mock of client interface. -type Mockclient struct { - ctrl *gomock.Controller - recorder *MockclientMockRecorder -} - -// MockclientMockRecorder is the mock recorder for Mockclient. -type MockclientMockRecorder struct { - mock *Mockclient -} - -// NewMockclient creates a new mock instance. -func NewMockclient(ctrl *gomock.Controller) *Mockclient { - mock := &Mockclient{ctrl: ctrl} - mock.recorder = &MockclientMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *Mockclient) EXPECT() *MockclientMockRecorder { - return m.recorder -} - -// CreateOrUpdateAsync mocks base method. -func (m *Mockclient) CreateOrUpdateAsync(arg0 context.Context, arg1, arg2, arg3 string, arg4 compute.VirtualMachineExtension) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateOrUpdateAsync", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(error) - return ret0 -} - -// CreateOrUpdateAsync indicates an expected call of CreateOrUpdateAsync. -func (mr *MockclientMockRecorder) CreateOrUpdateAsync(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOrUpdateAsync", reflect.TypeOf((*Mockclient)(nil).CreateOrUpdateAsync), arg0, arg1, arg2, arg3, arg4) -} - -// Delete mocks base method. -func (m *Mockclient) Delete(arg0 context.Context, arg1, arg2, arg3 string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Delete", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// Delete indicates an expected call of Delete. -func (mr *MockclientMockRecorder) Delete(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*Mockclient)(nil).Delete), arg0, arg1, arg2, arg3) -} - -// Get mocks base method. -func (m *Mockclient) Get(ctx context.Context, resourceGroupName, vmName, name string) (compute.VirtualMachineExtension, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, vmName, name) - ret0, _ := ret[0].(compute.VirtualMachineExtension) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Get indicates an expected call of Get. -func (mr *MockclientMockRecorder) Get(ctx, resourceGroupName, vmName, name interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*Mockclient)(nil).Get), ctx, resourceGroupName, vmName, name) -} diff --git a/azure/services/vmextensions/mock_vmextensions/doc.go b/azure/services/vmextensions/mock_vmextensions/doc.go index 13054e2b2f6..ab9b313ce95 100644 --- a/azure/services/vmextensions/mock_vmextensions/doc.go +++ b/azure/services/vmextensions/mock_vmextensions/doc.go @@ -15,8 +15,6 @@ limitations under the License. */ // Run go generate to regenerate this mock. -//go:generate ../../../../hack/tools/bin/mockgen -destination client_mock.go -package mock_vmextensions -source ../client.go Client //go:generate ../../../../hack/tools/bin/mockgen -destination vmextensions_mock.go -package mock_vmextensions -source ../vmextensions.go VMExtensionScope -//go:generate /usr/bin/env bash -c "cat ../../../../hack/boilerplate/boilerplate.generatego.txt client_mock.go > _client_mock.go && mv _client_mock.go client_mock.go" //go:generate /usr/bin/env bash -c "cat ../../../../hack/boilerplate/boilerplate.generatego.txt vmextensions_mock.go > _vmextensions_mock.go && mv _vmextensions_mock.go vmextensions_mock.go" package mock_vmextensions //nolint diff --git a/azure/services/vmextensions/mock_vmextensions/vmextensions_mock.go b/azure/services/vmextensions/mock_vmextensions/vmextensions_mock.go index 298108959a8..985a6838738 100644 --- a/azure/services/vmextensions/mock_vmextensions/vmextensions_mock.go +++ b/azure/services/vmextensions/mock_vmextensions/vmextensions_mock.go @@ -28,6 +28,7 @@ import ( gomock "github.com/golang/mock/gomock" v1beta1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" azure "sigs.k8s.io/cluster-api-provider-azure/azure" + v1beta10 "sigs.k8s.io/cluster-api/api/v1beta1" ) // MockVMExtensionScope is a mock of VMExtensionScope interface. @@ -179,6 +180,18 @@ func (mr *MockVMExtensionScopeMockRecorder) ClusterName() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterName", reflect.TypeOf((*MockVMExtensionScope)(nil).ClusterName)) } +// DeleteLongRunningOperationState mocks base method. +func (m *MockVMExtensionScope) DeleteLongRunningOperationState(arg0, arg1 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DeleteLongRunningOperationState", arg0, arg1) +} + +// DeleteLongRunningOperationState indicates an expected call of DeleteLongRunningOperationState. +func (mr *MockVMExtensionScopeMockRecorder) DeleteLongRunningOperationState(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteLongRunningOperationState", reflect.TypeOf((*MockVMExtensionScope)(nil).DeleteLongRunningOperationState), arg0, arg1) +} + // FailureDomains mocks base method. func (m *MockVMExtensionScope) FailureDomains() []string { m.ctrl.T.Helper() @@ -193,6 +206,20 @@ func (mr *MockVMExtensionScopeMockRecorder) FailureDomains() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FailureDomains", reflect.TypeOf((*MockVMExtensionScope)(nil).FailureDomains)) } +// GetLongRunningOperationState mocks base method. +func (m *MockVMExtensionScope) GetLongRunningOperationState(arg0, arg1 string) *v1beta1.Future { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLongRunningOperationState", arg0, arg1) + ret0, _ := ret[0].(*v1beta1.Future) + return ret0 +} + +// GetLongRunningOperationState indicates an expected call of GetLongRunningOperationState. +func (mr *MockVMExtensionScopeMockRecorder) GetLongRunningOperationState(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLongRunningOperationState", reflect.TypeOf((*MockVMExtensionScope)(nil).GetLongRunningOperationState), arg0, arg1) +} + // HashKey mocks base method. func (m *MockVMExtensionScope) HashKey() string { m.ctrl.T.Helper() @@ -249,6 +276,18 @@ func (mr *MockVMExtensionScopeMockRecorder) SetBootstrapConditions(arg0, arg1, a return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBootstrapConditions", reflect.TypeOf((*MockVMExtensionScope)(nil).SetBootstrapConditions), arg0, arg1, arg2) } +// SetLongRunningOperationState mocks base method. +func (m *MockVMExtensionScope) SetLongRunningOperationState(arg0 *v1beta1.Future) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetLongRunningOperationState", arg0) +} + +// SetLongRunningOperationState indicates an expected call of SetLongRunningOperationState. +func (mr *MockVMExtensionScopeMockRecorder) SetLongRunningOperationState(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLongRunningOperationState", reflect.TypeOf((*MockVMExtensionScope)(nil).SetLongRunningOperationState), arg0) +} + // SubscriptionID mocks base method. func (m *MockVMExtensionScope) SubscriptionID() string { m.ctrl.T.Helper() @@ -277,11 +316,47 @@ func (mr *MockVMExtensionScopeMockRecorder) TenantID() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TenantID", reflect.TypeOf((*MockVMExtensionScope)(nil).TenantID)) } +// UpdateDeleteStatus mocks base method. +func (m *MockVMExtensionScope) UpdateDeleteStatus(arg0 v1beta10.ConditionType, arg1 string, arg2 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdateDeleteStatus", arg0, arg1, arg2) +} + +// UpdateDeleteStatus indicates an expected call of UpdateDeleteStatus. +func (mr *MockVMExtensionScopeMockRecorder) UpdateDeleteStatus(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateDeleteStatus", reflect.TypeOf((*MockVMExtensionScope)(nil).UpdateDeleteStatus), arg0, arg1, arg2) +} + +// UpdatePatchStatus mocks base method. +func (m *MockVMExtensionScope) UpdatePatchStatus(arg0 v1beta10.ConditionType, arg1 string, arg2 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatePatchStatus", arg0, arg1, arg2) +} + +// UpdatePatchStatus indicates an expected call of UpdatePatchStatus. +func (mr *MockVMExtensionScopeMockRecorder) UpdatePatchStatus(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePatchStatus", reflect.TypeOf((*MockVMExtensionScope)(nil).UpdatePatchStatus), arg0, arg1, arg2) +} + +// UpdatePutStatus mocks base method. +func (m *MockVMExtensionScope) UpdatePutStatus(arg0 v1beta10.ConditionType, arg1 string, arg2 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatePutStatus", arg0, arg1, arg2) +} + +// UpdatePutStatus indicates an expected call of UpdatePutStatus. +func (mr *MockVMExtensionScopeMockRecorder) UpdatePutStatus(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePutStatus", reflect.TypeOf((*MockVMExtensionScope)(nil).UpdatePutStatus), arg0, arg1, arg2) +} + // VMExtensionSpecs mocks base method. -func (m *MockVMExtensionScope) VMExtensionSpecs() []azure.ExtensionSpec { +func (m *MockVMExtensionScope) VMExtensionSpecs() []azure.ResourceSpecGetter { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "VMExtensionSpecs") - ret0, _ := ret[0].([]azure.ExtensionSpec) + ret0, _ := ret[0].([]azure.ResourceSpecGetter) return ret0 } diff --git a/azure/services/vmextensions/spec.go b/azure/services/vmextensions/spec.go new file mode 100644 index 00000000000..9396aa1b696 --- /dev/null +++ b/azure/services/vmextensions/spec.go @@ -0,0 +1,70 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package vmextensions + +import ( + "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2021-04-01/compute" + "github.com/Azure/go-autorest/autorest/to" + "github.com/pkg/errors" + "sigs.k8s.io/cluster-api-provider-azure/azure" +) + +// VMExtensionSpec defines the specification for a VM or VMScaleSet extension. +type VMExtensionSpec struct { + azure.ExtensionSpec + ResourceGroup string + Location string +} + +// ResourceName returns the name of the VM extension. +func (s *VMExtensionSpec) ResourceName() string { + return s.Name +} + +// ResourceGroupName returns the name of the resource group. +func (s *VMExtensionSpec) ResourceGroupName() string { + return s.ResourceGroup +} + +// OwnerResourceName returns the name of the VM that owns this VM extension. +func (s *VMExtensionSpec) OwnerResourceName() string { + return s.VMName +} + +// Parameters returns the parameters for the VM extension. +func (s *VMExtensionSpec) Parameters(existing interface{}) (interface{}, error) { + if existing != nil { + _, ok := existing.(compute.VirtualMachineExtension) + if !ok { + return nil, errors.Errorf("%T is not a compute.VirtualMachineExtension", existing) + } + + // VM extension already exists, nothing to update. + return nil, nil + } + + return compute.VirtualMachineExtension{ + VirtualMachineExtensionProperties: &compute.VirtualMachineExtensionProperties{ + Publisher: to.StringPtr(s.Publisher), + Type: to.StringPtr(s.Name), + TypeHandlerVersion: to.StringPtr(s.Version), + Settings: nil, + ProtectedSettings: s.ProtectedSettings, + }, + Location: to.StringPtr(s.Location), + }, nil +} diff --git a/azure/services/vmextensions/vmextensions.go b/azure/services/vmextensions/vmextensions.go index a06c148c2a8..a19694eb53a 100644 --- a/azure/services/vmextensions/vmextensions.go +++ b/azure/services/vmextensions/vmextensions.go @@ -23,6 +23,8 @@ import ( "github.com/Azure/go-autorest/autorest/to" "github.com/pkg/errors" "sigs.k8s.io/cluster-api-provider-azure/azure" + "sigs.k8s.io/cluster-api-provider-azure/azure/services/async" + "sigs.k8s.io/cluster-api-provider-azure/util/reconciler" "sigs.k8s.io/cluster-api-provider-azure/util/tele" ) @@ -30,22 +32,25 @@ const serviceName = "vmextensions" // VMExtensionScope defines the scope interface for a vm extension service. type VMExtensionScope interface { + azure.Authorizer + azure.AsyncStatusUpdater azure.ClusterDescriber - VMExtensionSpecs() []azure.ExtensionSpec + VMExtensionSpecs() []azure.ResourceSpecGetter SetBootstrapConditions(context.Context, string, string) error } // Service provides operations on Azure resources. type Service struct { Scope VMExtensionScope - client + async.Reconciler } // New creates a new vm extension service. func New(scope VMExtensionScope) *Service { + client := newClient(scope) return &Service{ - Scope: scope, - client: newClient(scope), + Scope: scope, + Reconciler: async.New(scope, client, client), } } @@ -56,47 +61,45 @@ func (s *Service) Name() string { // Reconcile creates or updates the VM extension. func (s *Service) Reconcile(ctx context.Context) error { - ctx, log, done := tele.StartSpanWithLogger(ctx, "vmextensions.Service.Reconcile") + ctx, _, done := tele.StartSpanWithLogger(ctx, "vmextensions.Service.Reconcile") defer done() - for _, extensionSpec := range s.Scope.VMExtensionSpecs() { - if existing, err := s.client.Get(ctx, s.Scope.ResourceGroup(), extensionSpec.VMName, extensionSpec.Name); err == nil { + ctx, cancel := context.WithTimeout(ctx, reconciler.DefaultAzureServiceReconcileTimeout) + defer cancel() + + specs := s.Scope.VMExtensionSpecs() + if len(specs) == 0 { + return nil + } + + // We go through the list of ExtensionSpecs to reconcile each one, independently of the result of the previous one. + // If multiple errors occur, we return the most pressing one. + // Order of precedence (highest -> lowest) is: error that is not an operationNotDoneError (i.e. error creating) -> operationNotDoneError (i.e. creating in progress) -> no error (i.e. created) + var resultErr error + for _, extensionSpec := range specs { + result, err := s.CreateResource(ctx, extensionSpec, serviceName) + if err != nil { + if !azure.IsOperationNotDoneError(err) || resultErr == nil { + resultErr = err + } + } else { + vmextension, ok := result.(compute.VirtualMachineExtension) + if !ok { + return errors.Errorf("%T is not a compute.VirtualMachineExtension", result) + } + // check the extension status and set the associated conditions. - if retErr := s.Scope.SetBootstrapConditions(ctx, to.String(existing.ProvisioningState), extensionSpec.Name); retErr != nil { + if retErr := s.Scope.SetBootstrapConditions(ctx, to.String(vmextension.ProvisioningState), extensionSpec.ResourceName()); retErr != nil { + // TODO: what precedence should this error have? return retErr } - // if the extension already exists, do not update it. - continue - } else if !azure.ResourceNotFound(err) { - return errors.Wrapf(err, "failed to get vm extension %s on vm %s", extensionSpec.Name, extensionSpec.VMName) - } - - log.V(2).Info("creating VM extension", "vm extension", extensionSpec.Name) - err := s.client.CreateOrUpdateAsync( - ctx, - s.Scope.ResourceGroup(), - extensionSpec.VMName, - extensionSpec.Name, - compute.VirtualMachineExtension{ - VirtualMachineExtensionProperties: &compute.VirtualMachineExtensionProperties{ - Publisher: to.StringPtr(extensionSpec.Publisher), - Type: to.StringPtr(extensionSpec.Name), - TypeHandlerVersion: to.StringPtr(extensionSpec.Version), - Settings: nil, - ProtectedSettings: extensionSpec.ProtectedSettings, - }, - Location: to.StringPtr(s.Scope.Location()), - }, - ) - if err != nil { - return errors.Wrapf(err, "failed to create VM extension %s on VM %s in resource group %s", extensionSpec.Name, extensionSpec.VMName, s.Scope.ResourceGroup()) } - log.V(2).Info("successfully created VM extension", "vm extension", extensionSpec.Name) } - return nil + + return resultErr } -// Delete is a no-op. Extensions will be deleted as part of VM deletion. +// Delete is a no-op. VM Extensions will be deleted as part of VM deletion. func (s *Service) Delete(_ context.Context) error { return nil } diff --git a/azure/services/vmextensions/vmextensions_test.go b/azure/services/vmextensions/vmextensions_test.go index 281838e928d..fa8af5097d0 100644 --- a/azure/services/vmextensions/vmextensions_test.go +++ b/azure/services/vmextensions/vmextensions_test.go @@ -27,169 +27,140 @@ import ( "github.com/golang/mock/gomock" . "github.com/onsi/gomega" "sigs.k8s.io/cluster-api-provider-azure/azure" + "sigs.k8s.io/cluster-api-provider-azure/azure/services/async/mock_async" "sigs.k8s.io/cluster-api-provider-azure/azure/services/vmextensions/mock_vmextensions" gomockinternal "sigs.k8s.io/cluster-api-provider-azure/internal/test/matchers/gomock" ) +var ( + extensionSpec1 = VMExtensionSpec{ + ExtensionSpec: azure.ExtensionSpec{ + Name: "my-extension-1", + VMName: "my-vm", + Publisher: "some-publisher", + Version: "1.0", + }, + ResourceGroup: "my-rg", + Location: "test-location", + } + + extensionSucceeded1 = compute.VirtualMachineExtension{ + VirtualMachineExtensionProperties: &compute.VirtualMachineExtensionProperties{ + Publisher: to.StringPtr("some-publisher"), + Type: to.StringPtr("my-extension-1"), + ProvisioningState: to.StringPtr(string(compute.ProvisioningStateSucceeded)), + }, + ID: to.StringPtr("fake/id"), + Name: to.StringPtr("my-extension-1"), + } + + extensionFailed1 = compute.VirtualMachineExtension{ + VirtualMachineExtensionProperties: &compute.VirtualMachineExtensionProperties{ + Publisher: to.StringPtr("some-publisher"), + Type: to.StringPtr("my-extension-1"), + ProvisioningState: to.StringPtr(string(compute.ProvisioningStateFailed)), + }, + ID: to.StringPtr("fake/id"), + Name: to.StringPtr("my-extension-1"), + } + + extensionCreating1 = compute.VirtualMachineExtension{ + VirtualMachineExtensionProperties: &compute.VirtualMachineExtensionProperties{ + Publisher: to.StringPtr("some-publisher"), + Type: to.StringPtr("my-extension-1"), + ProvisioningState: to.StringPtr(string(compute.ProvisioningStateCreating)), + }, + ID: to.StringPtr("fake/id"), + Name: to.StringPtr("my-extension-1"), + } + + extensionSpec2 = VMExtensionSpec{ + ExtensionSpec: azure.ExtensionSpec{ + Name: "my-extension-2", + VMName: "my-vm", + Publisher: "other-publisher", + Version: "2.0", + }, + ResourceGroup: "my-rg", + Location: "test-location", + } + + extensionSucceeded2 = compute.VirtualMachineExtension{ + VirtualMachineExtensionProperties: &compute.VirtualMachineExtensionProperties{ + Publisher: to.StringPtr("other-publisher"), + Type: to.StringPtr("my-extension-2"), + ProvisioningState: to.StringPtr(string(compute.ProvisioningStateSucceeded)), + }, + ID: to.StringPtr("fake/id-2"), + Name: to.StringPtr("my-extension-2"), + } + + internalError = autorest.NewErrorWithResponse("", "", &http.Response{StatusCode: 500}, "Internal Server Error") +) + func TestReconcileVMExtension(t *testing.T) { testcases := []struct { name string expectedError string - expect func(s *mock_vmextensions.MockVMExtensionScopeMockRecorder, m *mock_vmextensions.MockclientMockRecorder) + expect func(s *mock_vmextensions.MockVMExtensionScopeMockRecorder, r *mock_async.MockReconcilerMockRecorder) }{ { name: "extension is in succeeded state", expectedError: "", - expect: func(s *mock_vmextensions.MockVMExtensionScopeMockRecorder, m *mock_vmextensions.MockclientMockRecorder) { - s.VMExtensionSpecs().Return([]azure.ExtensionSpec{ - { - Name: "my-extension-1", - VMName: "my-vm", - Publisher: "some-publisher", - Version: "1.0", - }, - }) - s.ResourceGroup().AnyTimes().Return("my-rg") - s.Location().AnyTimes().Return("test-location") - m.Get(gomockinternal.AContext(), "my-rg", "my-vm", "my-extension-1").Return(compute.VirtualMachineExtension{ - VirtualMachineExtensionProperties: &compute.VirtualMachineExtensionProperties{ - Publisher: to.StringPtr("some-publisher"), - Type: to.StringPtr("my-extension-1"), - ProvisioningState: to.StringPtr(string(compute.ProvisioningStateSucceeded)), - }, - ID: to.StringPtr("fake/id"), - Name: to.StringPtr("my-extension-1"), - }, nil) - s.SetBootstrapConditions(gomockinternal.AContext(), string(compute.ProvisioningStateSucceeded), "my-extension-1") + expect: func(s *mock_vmextensions.MockVMExtensionScopeMockRecorder, r *mock_async.MockReconcilerMockRecorder) { + s.VMExtensionSpecs().Return([]azure.ResourceSpecGetter{&extensionSpec1}) + r.CreateResource(gomockinternal.AContext(), &extensionSpec1, serviceName).Return(extensionSucceeded1, nil) + s.SetBootstrapConditions(gomockinternal.AContext(), string(compute.ProvisioningStateSucceeded), extensionSpec1.ResourceName()).Return(nil) }, }, { name: "extension is in failed state", expectedError: "", - expect: func(s *mock_vmextensions.MockVMExtensionScopeMockRecorder, m *mock_vmextensions.MockclientMockRecorder) { - s.VMExtensionSpecs().Return([]azure.ExtensionSpec{ - { - Name: "my-extension-1", - VMName: "my-vm", - Publisher: "some-publisher", - Version: "1.0", - }, - }) - s.ResourceGroup().AnyTimes().Return("my-rg") - s.Location().AnyTimes().Return("test-location") - m.Get(gomockinternal.AContext(), "my-rg", "my-vm", "my-extension-1").Return(compute.VirtualMachineExtension{ - VirtualMachineExtensionProperties: &compute.VirtualMachineExtensionProperties{ - Publisher: to.StringPtr("some-publisher"), - Type: to.StringPtr("my-extension-1"), - ProvisioningState: to.StringPtr(string(compute.ProvisioningStateFailed)), - }, - ID: to.StringPtr("fake/id"), - Name: to.StringPtr("my-extension-1"), - }, nil) - s.SetBootstrapConditions(gomockinternal.AContext(), string(compute.ProvisioningStateFailed), "my-extension-1") + expect: func(s *mock_vmextensions.MockVMExtensionScopeMockRecorder, r *mock_async.MockReconcilerMockRecorder) { + s.VMExtensionSpecs().Return([]azure.ResourceSpecGetter{&extensionSpec1}) + r.CreateResource(gomockinternal.AContext(), &extensionSpec1, serviceName).Return(extensionFailed1, nil) + s.SetBootstrapConditions(gomockinternal.AContext(), string(compute.ProvisioningStateFailed), extensionSpec1.ResourceName()).Return(nil) }, }, { + // TODO: is this blocked by the operationNotDoneError? name: "extension is still creating", expectedError: "", - expect: func(s *mock_vmextensions.MockVMExtensionScopeMockRecorder, m *mock_vmextensions.MockclientMockRecorder) { - s.VMExtensionSpecs().Return([]azure.ExtensionSpec{ - { - Name: "my-extension-1", - VMName: "my-vm", - Publisher: "some-publisher", - Version: "1.0", - }, - }) - s.ResourceGroup().AnyTimes().Return("my-rg") - s.Location().AnyTimes().Return("test-location") - m.Get(gomockinternal.AContext(), "my-rg", "my-vm", "my-extension-1").Return(compute.VirtualMachineExtension{ - VirtualMachineExtensionProperties: &compute.VirtualMachineExtensionProperties{ - Publisher: to.StringPtr("some-publisher"), - Type: to.StringPtr("my-extension-1"), - ProvisioningState: to.StringPtr(string(compute.ProvisioningStateCreating)), - }, - ID: to.StringPtr("fake/id"), - Name: to.StringPtr("my-extension-1"), - }, nil) - s.SetBootstrapConditions(gomockinternal.AContext(), string(compute.ProvisioningStateCreating), "my-extension-1") + expect: func(s *mock_vmextensions.MockVMExtensionScopeMockRecorder, r *mock_async.MockReconcilerMockRecorder) { + s.VMExtensionSpecs().Return([]azure.ResourceSpecGetter{&extensionSpec1}) + r.CreateResource(gomockinternal.AContext(), &extensionSpec1, serviceName).Return(extensionCreating1, nil) + s.SetBootstrapConditions(gomockinternal.AContext(), string(compute.ProvisioningStateCreating), extensionSpec1.ResourceName()).Return(nil) }, }, { name: "reconcile multiple extensions", expectedError: "", - expect: func(s *mock_vmextensions.MockVMExtensionScopeMockRecorder, m *mock_vmextensions.MockclientMockRecorder) { - s.VMExtensionSpecs().Return([]azure.ExtensionSpec{ - { - Name: "my-extension-1", - VMName: "my-vm", - Publisher: "some-publisher", - Version: "1.0", - }, - { - Name: "other-extension", - VMName: "my-vm", - Publisher: "other-publisher", - Version: "2.0", - }, - }) - s.ResourceGroup().AnyTimes().Return("my-rg") - s.Location().AnyTimes().Return("test-location") - m.Get(gomockinternal.AContext(), "my-rg", "my-vm", "my-extension-1"). - Return(compute.VirtualMachineExtension{}, autorest.NewErrorWithResponse("", "", &http.Response{StatusCode: 404}, "Not found")) - m.CreateOrUpdateAsync(gomockinternal.AContext(), "my-rg", "my-vm", "my-extension-1", gomock.AssignableToTypeOf(compute.VirtualMachineExtension{})) - m.Get(gomockinternal.AContext(), "my-rg", "my-vm", "other-extension"). - Return(compute.VirtualMachineExtension{}, autorest.NewErrorWithResponse("", "", &http.Response{StatusCode: 404}, "Not found")) - m.CreateOrUpdateAsync(gomockinternal.AContext(), "my-rg", "my-vm", "other-extension", gomock.AssignableToTypeOf(compute.VirtualMachineExtension{})) + expect: func(s *mock_vmextensions.MockVMExtensionScopeMockRecorder, r *mock_async.MockReconcilerMockRecorder) { + s.VMExtensionSpecs().Return([]azure.ResourceSpecGetter{&extensionSpec1, &extensionSpec2}) + r.CreateResource(gomockinternal.AContext(), &extensionSpec1, serviceName).Return(extensionSucceeded1, nil) + s.SetBootstrapConditions(gomockinternal.AContext(), string(compute.ProvisioningStateSucceeded), extensionSpec1.ResourceName()).Return(nil) + r.CreateResource(gomockinternal.AContext(), &extensionSpec2, serviceName).Return(extensionSucceeded2, nil) + s.SetBootstrapConditions(gomockinternal.AContext(), string(compute.ProvisioningStateSucceeded), extensionSpec2.ResourceName()).Return(nil) }, }, { - name: "error getting the extension", - expectedError: "failed to get vm extension my-extension-1 on vm my-vm: #: Internal Server Error: StatusCode=500", - expect: func(s *mock_vmextensions.MockVMExtensionScopeMockRecorder, m *mock_vmextensions.MockclientMockRecorder) { - s.VMExtensionSpecs().Return([]azure.ExtensionSpec{ - { - Name: "my-extension-1", - VMName: "my-vm", - Publisher: "some-publisher", - Version: "1.0", - }, - { - Name: "other-extension", - VMName: "my-vm", - Publisher: "other-publisher", - Version: "2.0", - }, - }) - s.ResourceGroup().AnyTimes().Return("my-rg") - s.Location().AnyTimes().Return("test-location") - m.Get(gomockinternal.AContext(), "my-rg", "my-vm", "my-extension-1"). - Return(compute.VirtualMachineExtension{}, autorest.NewErrorWithResponse("", "", &http.Response{StatusCode: 500}, "Internal Server Error")) + name: "error creating the first extension", + expectedError: internalError.Error(), + expect: func(s *mock_vmextensions.MockVMExtensionScopeMockRecorder, r *mock_async.MockReconcilerMockRecorder) { + s.VMExtensionSpecs().Return([]azure.ResourceSpecGetter{&extensionSpec1, &extensionSpec2}) + r.CreateResource(gomockinternal.AContext(), &extensionSpec1, serviceName).Return(nil, internalError) + r.CreateResource(gomockinternal.AContext(), &extensionSpec2, serviceName).Return(extensionSucceeded2, nil) + s.SetBootstrapConditions(gomockinternal.AContext(), string(compute.ProvisioningStateSucceeded), extensionSpec2.ResourceName()).Return(nil) }, }, { - name: "error creating the extension", - expectedError: "failed to create VM extension my-extension-1 on VM my-vm in resource group my-rg: #: Internal Server Error: StatusCode=500", - expect: func(s *mock_vmextensions.MockVMExtensionScopeMockRecorder, m *mock_vmextensions.MockclientMockRecorder) { - s.VMExtensionSpecs().Return([]azure.ExtensionSpec{ - { - Name: "my-extension-1", - VMName: "my-vm", - Publisher: "some-publisher", - Version: "1.0", - }, - { - Name: "other-extension", - VMName: "my-vm", - Publisher: "other-publisher", - Version: "2.0", - }, - }) - s.ResourceGroup().AnyTimes().Return("my-rg") - s.Location().AnyTimes().Return("test-location") - m.Get(gomockinternal.AContext(), "my-rg", "my-vm", "my-extension-1"). - Return(compute.VirtualMachineExtension{}, autorest.NewErrorWithResponse("", "", &http.Response{StatusCode: 404}, "Not found")) - m.CreateOrUpdateAsync(gomockinternal.AContext(), "my-rg", "my-vm", "my-extension-1", gomock.AssignableToTypeOf(compute.VirtualMachineExtension{})).Return(autorest.NewErrorWithResponse("", "", &http.Response{StatusCode: 500}, "Internal Server Error")) + name: "error setting bootstrap conditions", + expectedError: internalError.Error(), + expect: func(s *mock_vmextensions.MockVMExtensionScopeMockRecorder, r *mock_async.MockReconcilerMockRecorder) { + s.VMExtensionSpecs().Return([]azure.ResourceSpecGetter{&extensionSpec1, &extensionSpec2}) + r.CreateResource(gomockinternal.AContext(), &extensionSpec1, serviceName).Return(extensionSucceeded1, nil) + s.SetBootstrapConditions(gomockinternal.AContext(), string(compute.ProvisioningStateSucceeded), extensionSpec1.ResourceName()).Return(internalError) + // TODO: update test depending on how errors from SetBootstrapConditions are handled }, }, } @@ -203,13 +174,13 @@ func TestReconcileVMExtension(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() scopeMock := mock_vmextensions.NewMockVMExtensionScope(mockCtrl) - clientMock := mock_vmextensions.NewMockclient(mockCtrl) + asyncMock := mock_async.NewMockReconciler(mockCtrl) - tc.expect(scopeMock.EXPECT(), clientMock.EXPECT()) + tc.expect(scopeMock.EXPECT(), asyncMock.EXPECT()) s := &Service{ - Scope: scopeMock, - client: clientMock, + Scope: scopeMock, + Reconciler: asyncMock, } err := s.Reconcile(context.TODO()) diff --git a/azure/services/vmssextensions/client.go b/azure/services/vmssextensions/client.go index 3699e07628f..4dbdc972492 100644 --- a/azure/services/vmssextensions/client.go +++ b/azure/services/vmssextensions/client.go @@ -27,7 +27,7 @@ import ( // Client wraps go-sdk. type client interface { - Get(context.Context, string, string, string) (compute.VirtualMachineScaleSetExtension, error) + Get(context.Context, azure.ResourceSpecGetter) (result interface{}, err error) } // AzureClient contains the Azure go-sdk Client. @@ -50,10 +50,10 @@ func newVirtualMachineScaleSetExtensionsClient(subscriptionID string, baseURI st return vmssextensionsClient } -// Get creates or updates the virtual machine scale set extension. -func (ac *azureClient) Get(ctx context.Context, resourceGroupName, vmssName, name string) (compute.VirtualMachineScaleSetExtension, error) { +// Get the virtual machine scale set extension. +func (ac *azureClient) Get(ctx context.Context, spec azure.ResourceSpecGetter) (result interface{}, err error) { ctx, _, done := tele.StartSpanWithLogger(ctx, "vmssextensions.AzureClient.Get") defer done() - return ac.vmssextensions.Get(ctx, resourceGroupName, vmssName, name, "") + return ac.vmssextensions.Get(ctx, spec.ResourceGroupName(), spec.OwnerResourceName(), spec.ResourceName(), "") } diff --git a/azure/services/vmssextensions/mock_vmssextensions/client_mock.go b/azure/services/vmssextensions/mock_vmssextensions/client_mock.go index 9b7bf1048f2..81d0a1dfa98 100644 --- a/azure/services/vmssextensions/mock_vmssextensions/client_mock.go +++ b/azure/services/vmssextensions/mock_vmssextensions/client_mock.go @@ -24,8 +24,8 @@ import ( context "context" reflect "reflect" - compute "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2021-04-01/compute" gomock "github.com/golang/mock/gomock" + azure "sigs.k8s.io/cluster-api-provider-azure/azure" ) // Mockclient is a mock of client interface. @@ -52,16 +52,16 @@ func (m *Mockclient) EXPECT() *MockclientMockRecorder { } // Get mocks base method. -func (m *Mockclient) Get(arg0 context.Context, arg1, arg2, arg3 string) (compute.VirtualMachineScaleSetExtension, error) { +func (m *Mockclient) Get(arg0 context.Context, arg1 azure.ResourceSpecGetter) (interface{}, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(compute.VirtualMachineScaleSetExtension) + ret := m.ctrl.Call(m, "Get", arg0, arg1) + ret0, _ := ret[0].(interface{}) ret1, _ := ret[1].(error) return ret0, ret1 } // Get indicates an expected call of Get. -func (mr *MockclientMockRecorder) Get(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockclientMockRecorder) Get(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*Mockclient)(nil).Get), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*Mockclient)(nil).Get), arg0, arg1) } diff --git a/azure/services/vmssextensions/mock_vmssextensions/vmssextensions_mock.go b/azure/services/vmssextensions/mock_vmssextensions/vmssextensions_mock.go index 40b2cb626e9..ba56d468b32 100644 --- a/azure/services/vmssextensions/mock_vmssextensions/vmssextensions_mock.go +++ b/azure/services/vmssextensions/mock_vmssextensions/vmssextensions_mock.go @@ -278,10 +278,10 @@ func (mr *MockVMSSExtensionScopeMockRecorder) TenantID() *gomock.Call { } // VMSSExtensionSpecs mocks base method. -func (m *MockVMSSExtensionScope) VMSSExtensionSpecs() []azure.ExtensionSpec { +func (m *MockVMSSExtensionScope) VMSSExtensionSpecs() []azure.ResourceSpecGetter { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "VMSSExtensionSpecs") - ret0, _ := ret[0].([]azure.ExtensionSpec) + ret0, _ := ret[0].([]azure.ResourceSpecGetter) return ret0 } diff --git a/azure/services/vmssextensions/spec.go b/azure/services/vmssextensions/spec.go new file mode 100644 index 00000000000..78c4b90d09e --- /dev/null +++ b/azure/services/vmssextensions/spec.go @@ -0,0 +1,70 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package vmssextensions + +import ( + "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2021-04-01/compute" + "github.com/Azure/go-autorest/autorest/to" + "github.com/pkg/errors" + "sigs.k8s.io/cluster-api-provider-azure/azure" +) + +// VMSSExtensionSpec defines the specification for a VM or VMScaleSet extension. +type VMSSExtensionSpec struct { + azure.ExtensionSpec + ResourceGroup string +} + +// ResourceName returns the name of the VMSS extension. +func (s *VMSSExtensionSpec) ResourceName() string { + return s.Name +} + +// ResourceGroupName returns the name of the resource group. +func (s *VMSSExtensionSpec) ResourceGroupName() string { + return s.ResourceGroup +} + +// OwnerResourceName returns the name of the VMSS that owns this VMSS extension. +func (s *VMSSExtensionSpec) OwnerResourceName() string { + return s.VMName +} + +// Parameters returns the parameters for the VMSS extension. +func (s *VMSSExtensionSpec) Parameters(existing interface{}) (interface{}, error) { + if existing != nil { + _, ok := existing.(compute.VirtualMachineScaleSetExtension) + if !ok { + return nil, errors.Errorf("%T is not a compute.VirtualMachineScaleSetExtension", existing) + } + + // VMSS extension already exists, nothing to update. + return nil, nil + } + + return compute.VirtualMachineScaleSetExtension{ + Name: to.StringPtr(s.Name), + VirtualMachineScaleSetExtensionProperties: &compute.VirtualMachineScaleSetExtensionProperties{ + Publisher: to.StringPtr(s.Publisher), + Type: to.StringPtr(s.Name), + TypeHandlerVersion: to.StringPtr(s.Version), + Settings: nil, + ProtectedSettings: s.ProtectedSettings, + }, + // TODO: should we include location since it's used in VMExtensions too? + }, nil +} diff --git a/azure/services/vmssextensions/vmssextensions.go b/azure/services/vmssextensions/vmssextensions.go index e3885600f85..d07b674ffbf 100644 --- a/azure/services/vmssextensions/vmssextensions.go +++ b/azure/services/vmssextensions/vmssextensions.go @@ -19,6 +19,7 @@ package vmssextensions import ( "context" + "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2021-04-01/compute" "github.com/Azure/go-autorest/autorest/to" "github.com/pkg/errors" "sigs.k8s.io/cluster-api-provider-azure/azure" @@ -30,7 +31,7 @@ const serviceName = "vmssextensions" // VMSSExtensionScope defines the scope interface for a vmss extension service. type VMSSExtensionScope interface { azure.ClusterDescriber - VMSSExtensionSpecs() []azure.ExtensionSpec + VMSSExtensionSpecs() []azure.ResourceSpecGetter SetBootstrapConditions(context.Context, string, string) error } @@ -59,13 +60,17 @@ func (s *Service) Reconcile(ctx context.Context) error { defer done() for _, extensionSpec := range s.Scope.VMSSExtensionSpecs() { - if existing, err := s.client.Get(ctx, s.Scope.ResourceGroup(), extensionSpec.VMName, extensionSpec.Name); err == nil { + if existing, err := s.client.Get(ctx, extensionSpec); err == nil { + extension, ok := existing.(compute.VirtualMachineScaleSetExtension) + if !ok { + return errors.Errorf("%T is not a compute.VirtualMachineScaleSetExtension", existing) + } // check the extension status and set the associated conditions. - if retErr := s.Scope.SetBootstrapConditions(ctx, to.String(existing.ProvisioningState), extensionSpec.Name); retErr != nil { + if retErr := s.Scope.SetBootstrapConditions(ctx, to.String(extension.ProvisioningState), extensionSpec.ResourceName()); retErr != nil { return retErr } } else if !azure.ResourceNotFound(err) { - return errors.Wrapf(err, "failed to get vm extension %s on scale set %s", extensionSpec.Name, extensionSpec.VMName) + return errors.Wrapf(err, "failed to get vm extension %s on scale set %s", extensionSpec.ResourceName(), extensionSpec.OwnerResourceName()) } // Nothing else to do here, the extensions are applied to the model as part of the scale set Reconcile. continue diff --git a/azure/services/vmssextensions/vmssextensions_test.go b/azure/services/vmssextensions/vmssextensions_test.go index 57c270e33da..cb5afc223fc 100644 --- a/azure/services/vmssextensions/vmssextensions_test.go +++ b/azure/services/vmssextensions/vmssextensions_test.go @@ -31,6 +31,40 @@ import ( gomockinternal "sigs.k8s.io/cluster-api-provider-azure/internal/test/matchers/gomock" ) +var ( + fakeExtensionSpec = VMSSExtensionSpec{ + ExtensionSpec: azure.ExtensionSpec{ + Name: "my-extension-1", + VMName: "my-vmss", + Publisher: "some-publisher", + Version: "1.0", + }, + ResourceGroup: "my-rg", + } + + fakeExtension = compute.VirtualMachineScaleSetExtension{ + Name: to.StringPtr("my-extension-1"), + VirtualMachineScaleSetExtensionProperties: &compute.VirtualMachineScaleSetExtensionProperties{ + Publisher: to.StringPtr("some-publisher"), + Type: to.StringPtr("my-extension-1"), + ProvisioningState: to.StringPtr(string(compute.ProvisioningStateSucceeded)), + }, + ID: to.StringPtr("some/fake/id"), + } + + fakeExtensionSpec2 = VMSSExtensionSpec{ + ExtensionSpec: azure.ExtensionSpec{ + Name: "other-extension", + VMName: "my-vmss", + Publisher: "other-publisher", + Version: "2.0", + }, + ResourceGroup: "my-rg", + } + + notFoundErr = autorest.NewErrorWithResponse("", "", &http.Response{StatusCode: 404}, "Not found") +) + func TestReconcileVMSSExtension(t *testing.T) { testcases := []struct { name string @@ -41,25 +75,8 @@ func TestReconcileVMSSExtension(t *testing.T) { name: "extension already exists", expectedError: "", expect: func(s *mock_vmssextensions.MockVMSSExtensionScopeMockRecorder, m *mock_vmssextensions.MockclientMockRecorder) { - s.VMSSExtensionSpecs().Return([]azure.ExtensionSpec{ - { - Name: "my-extension-1", - VMName: "my-vmss", - Publisher: "some-publisher", - Version: "1.0", - }, - }) - s.ResourceGroup().AnyTimes().Return("my-rg") - s.Location().AnyTimes().Return("test-location") - m.Get(gomockinternal.AContext(), "my-rg", "my-vmss", "my-extension-1").Return(compute.VirtualMachineScaleSetExtension{ - Name: to.StringPtr("my-extension-1"), - VirtualMachineScaleSetExtensionProperties: &compute.VirtualMachineScaleSetExtensionProperties{ - Publisher: to.StringPtr("some-publisher"), - Type: to.StringPtr("my-extension-1"), - ProvisioningState: to.StringPtr(string(compute.ProvisioningStateSucceeded)), - }, - ID: to.StringPtr("some/fake/id"), - }, nil) + s.VMSSExtensionSpecs().Return([]azure.ResourceSpecGetter{&fakeExtensionSpec}) + m.Get(gomockinternal.AContext(), &fakeExtensionSpec).Return(fakeExtension, nil) s.SetBootstrapConditions(gomockinternal.AContext(), string(compute.ProvisioningStateSucceeded), "my-extension-1") }, }, @@ -67,50 +84,17 @@ func TestReconcileVMSSExtension(t *testing.T) { name: "extension does not exist", expectedError: "", expect: func(s *mock_vmssextensions.MockVMSSExtensionScopeMockRecorder, m *mock_vmssextensions.MockclientMockRecorder) { - s.VMSSExtensionSpecs().Return([]azure.ExtensionSpec{ - { - Name: "my-extension-1", - VMName: "my-vmss", - Publisher: "some-publisher", - Version: "1.0", - }, - { - Name: "other-extension", - VMName: "my-vmss", - Publisher: "other-publisher", - Version: "2.0", - }, - }) - s.ResourceGroup().AnyTimes().Return("my-rg") - s.Location().AnyTimes().Return("test-location") - m.Get(gomockinternal.AContext(), "my-rg", "my-vmss", "my-extension-1"). - Return(compute.VirtualMachineScaleSetExtension{}, autorest.NewErrorWithResponse("", "", &http.Response{StatusCode: 404}, "Not found")) - m.Get(gomockinternal.AContext(), "my-rg", "my-vmss", "other-extension"). - Return(compute.VirtualMachineScaleSetExtension{}, autorest.NewErrorWithResponse("", "", &http.Response{StatusCode: 404}, "Not found")) + s.VMSSExtensionSpecs().Return([]azure.ResourceSpecGetter{&fakeExtensionSpec, &fakeExtensionSpec2}) + m.Get(gomockinternal.AContext(), &fakeExtensionSpec).Return(nil, notFoundErr) + m.Get(gomockinternal.AContext(), &fakeExtensionSpec2).Return(nil, notFoundErr) }, }, { name: "error getting the extension", expectedError: "failed to get vm extension my-extension-1 on scale set my-vmss: #: Internal Server Error: StatusCode=500", expect: func(s *mock_vmssextensions.MockVMSSExtensionScopeMockRecorder, m *mock_vmssextensions.MockclientMockRecorder) { - s.VMSSExtensionSpecs().Return([]azure.ExtensionSpec{ - { - Name: "my-extension-1", - VMName: "my-vmss", - Publisher: "some-publisher", - Version: "1.0", - }, - { - Name: "other-extension", - VMName: "my-vmss", - Publisher: "other-publisher", - Version: "2.0", - }, - }) - s.ResourceGroup().AnyTimes().Return("my-rg") - s.Location().AnyTimes().Return("test-location") - m.Get(gomockinternal.AContext(), "my-rg", "my-vmss", "my-extension-1"). - Return(compute.VirtualMachineScaleSetExtension{}, autorest.NewErrorWithResponse("", "", &http.Response{StatusCode: 500}, "Internal Server Error")) + s.VMSSExtensionSpecs().Return([]azure.ResourceSpecGetter{&fakeExtensionSpec, &fakeExtensionSpec2}) + m.Get(gomockinternal.AContext(), &fakeExtensionSpec).Return(nil, autorest.NewErrorWithResponse("", "", &http.Response{StatusCode: 500}, "Internal Server Error")) }, }, } diff --git a/azure/types.go b/azure/types.go index 08b0fed91c9..392a9c74779 100644 --- a/azure/types.go +++ b/azure/types.go @@ -96,7 +96,7 @@ type PrivateDNSLinkSpec struct { LinkName string } -// ExtensionSpec defines the specification for a VM or VMScaleSet extension. +// ExtensionSpec defines the specification for a VM or VMSS extension. type ExtensionSpec struct { Name string VMName string