diff --git a/azure/scope/machinepool.go b/azure/scope/machinepool.go index 9f13ad92ee0a..bd7d3614c740 100644 --- a/azure/scope/machinepool.go +++ b/azure/scope/machinepool.go @@ -33,6 +33,7 @@ import ( infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" "sigs.k8s.io/cluster-api-provider-azure/azure" machinepool "sigs.k8s.io/cluster-api-provider-azure/azure/scope/strategies/machinepool_deployments" + "sigs.k8s.io/cluster-api-provider-azure/azure/services/resourceskus" "sigs.k8s.io/cluster-api-provider-azure/azure/services/roleassignments" "sigs.k8s.io/cluster-api-provider-azure/azure/services/scalesets" "sigs.k8s.io/cluster-api-provider-azure/azure/services/virtualmachineimages" @@ -62,6 +63,7 @@ type ( MachinePool *expv1.MachinePool AzureMachinePool *infrav1exp.AzureMachinePool ClusterScope azure.ClusterScoper + Cache *MachinePoolCache } // MachinePoolScope defines a scope defined around a machine pool and its cluster. @@ -73,6 +75,7 @@ type ( patchHelper *patch.Helper capiMachinePoolPatchHelper *patch.Helper vmssState *azure.VMSS + cache *MachinePoolCache } // NodeStatus represents the status of a Kubernetes node. @@ -80,6 +83,14 @@ type ( Ready bool Version string } + + // MachinePoolCache stores common machine pool information so we don't have to hit the API multiple times within the same reconcile loop. + MachinePoolCache struct { + BootstrapData string + VMImage *infrav1.Image + VMSKU resourceskus.SKU + availabilitySetSKU resourceskus.SKU + } ) // NewMachinePoolScope creates a new MachinePoolScope from the supplied parameters. 
@@ -117,9 +128,77 @@ func NewMachinePoolScope(params MachinePoolScopeParams) (*MachinePoolScope, erro }, nil } +func (m *MachinePoolScope) InitMachinePoolCache(ctx context.Context) error { + ctx, _, done := tele.StartSpanWithLogger(ctx, "azure.MachinePoolScope.InitMachinePoolCache") + defer done() + + if m.cache == nil { + var err error + m.cache = &MachinePoolCache{} + + m.cache.BootstrapData, err = m.GetBootstrapData(ctx) + if err != nil { + return err + } + + m.cache.VMImage, err = m.GetVMImage(ctx) + if err != nil { + return err + } + + skuCache, err := resourceskus.GetCache(m, m.Location()) + if err != nil { + return err + } + + m.cache.VMSKU, err = skuCache.Get(ctx, m.AzureMachinePool.Spec.Template.VMSize, resourceskus.VirtualMachines) + if err != nil { + return errors.Wrapf(err, "failed to get VM SKU %s in compute api", m.AzureMachinePool.Spec.Template.VMSize) + } + + } + + return nil +} + // ScaleSetSpec returns the scale set spec. -func (m *MachinePoolScope) ScaleSetSpec() azure.ScaleSetSpec { - return azure.ScaleSetSpec{ +func (m *MachinePoolScope) ScaleSetSpec(ctx context.Context) azure.ResourceSpecGetter { + ctx, log, done := tele.StartSpanWithLogger(ctx, "scope.MachinePoolScope.ScaleSetSpec") + defer done() + + vmImage, err := m.GetVMImage(ctx) + if err != nil { + log.Error(err, "failed to get VM image") + // TODO: do we just leave this nil or return an error? + } + + bootstrapData, err := m.GetBootstrapData(ctx) + if err != nil { + log.Error(err, "failed to get bootstrap data") + // TODO: do we return early here? + } + + maxSurge, err := m.MaxSurge() + if err != nil { + log.Error(err, "failed to get max surge") + // TODO: do we return early here? 
+ } + + shouldPatchCustomData := false + if m.HasReplicasExternallyManaged(ctx) { + shouldPatchCustomData, err = m.HasBootstrapDataChanges(ctx) + if err != nil { + log.Error(err, "failed to check for bootstrap data changes") + // return nil, errors.Wrap(err, "unable to calculate custom data hash") + } + if shouldPatchCustomData { + log.V(4).Info("custom data changed") + } else { + log.V(4).Info("custom data unchanged") + } + } + + return &scalesets.ScaleSetSpec{ Name: m.Name(), Size: m.AzureMachinePool.Spec.Template.VMSize, Capacity: int64(pointer.Int32Deref(m.MachinePool.Spec.Replicas, 0)), @@ -142,6 +221,17 @@ func (m *MachinePoolScope) ScaleSetSpec() azure.ScaleSetSpec { NetworkInterfaces: m.AzureMachinePool.Spec.Template.NetworkInterfaces, IPv6Enabled: m.IsIPv6Enabled(), OrchestrationMode: m.AzureMachinePool.Spec.OrchestrationMode, + Location: m.AzureMachinePool.Spec.Location, + SubscriptionID: m.SubscriptionID(), + VMSSExtensionSpecs: m.VMSSExtensionSpecs(), + VMImage: vmImage, + BootstrapData: bootstrapData, + ClusterName: m.ClusterName(), + AdditionalTags: m.AzureMachinePool.Spec.AdditionalTags, + MaxSurge: maxSurge, + SKU: m.cache.VMSKU, + ShouldPatchCustomData: shouldPatchCustomData, + // VMSSInstances []compute.VirtualMachineScaleSetVM } } diff --git a/azure/services/roleassignments/roleassignments.go b/azure/services/roleassignments/roleassignments.go index 99dd72f1d3e4..027fc7cffb40 100644 --- a/azure/services/roleassignments/roleassignments.go +++ b/azure/services/roleassignments/roleassignments.go @@ -47,7 +47,7 @@ type Service struct { Scope RoleAssignmentScope virtualMachinesGetter async.Getter async.Reconciler - virtualMachineScaleSetClient scalesets.Client + virtualMachineScaleSetGetter async.Getter } // New creates a new service. 
@@ -56,7 +56,7 @@ func New(scope RoleAssignmentScope) *Service { return &Service{ Scope: scope, virtualMachinesGetter: virtualmachines.NewClient(scope), - virtualMachineScaleSetClient: scalesets.NewClient(scope), + virtualMachineScaleSetGetter: scalesets.NewClient(scope), Reconciler: async.New(scope, client, client), } } @@ -141,10 +141,20 @@ func (s *Service) getVMSSPrincipalID(ctx context.Context) (*string, error) { ctx, log, done := tele.StartSpanWithLogger(ctx, "roleassignments.Service.getVMPrincipalID") defer done() log.V(2).Info("fetching principal ID for VMSS") - resultVMSS, err := s.virtualMachineScaleSetClient.Get(ctx, s.Scope.ResourceGroup(), s.Scope.Name()) + spec := &scalesets.ScaleSetSpec{ + Name: s.Scope.Name(), + ResourceGroup: s.Scope.ResourceGroup(), + } + + resultVMSSIface, err := s.virtualMachineScaleSetGetter.Get(ctx, spec) if err != nil { return nil, errors.Wrap(err, "failed to get principal ID for VMSS") } + resultVMSS, ok := resultVMSSIface.(compute.VirtualMachineScaleSet) + if !ok { + return nil, errors.Errorf("%T is not a compute.VirtualMachineScaleSet", resultVMSSIface) + } + return resultVMSS.Identity.PrincipalID, nil } diff --git a/azure/services/roleassignments/roleassignments_test.go b/azure/services/roleassignments/roleassignments_test.go index ca542f2ecb4a..8d760efbe651 100644 --- a/azure/services/roleassignments/roleassignments_test.go +++ b/azure/services/roleassignments/roleassignments_test.go @@ -30,6 +30,7 @@ import ( "sigs.k8s.io/cluster-api-provider-azure/azure" "sigs.k8s.io/cluster-api-provider-azure/azure/services/async/mock_async" "sigs.k8s.io/cluster-api-provider-azure/azure/services/roleassignments/mock_roleassignments" + "sigs.k8s.io/cluster-api-provider-azure/azure/services/scalesets" "sigs.k8s.io/cluster-api-provider-azure/azure/services/scalesets/mock_scalesets" "sigs.k8s.io/cluster-api-provider-azure/azure/services/virtualmachines" gomockinternal "sigs.k8s.io/cluster-api-provider-azure/internal/test/matchers/gomock" @@ -55,6 
+56,10 @@ var ( emptyRoleAssignmentSpec = RoleAssignmentSpec{} fakeRoleAssignmentSpecs = []azure.ResourceSpecGetter{&fakeRoleAssignment1, &fakeRoleAssignment2, &emptyRoleAssignmentSpec} + fakeVMSSSpec = scalesets.ScaleSetSpec{ + Name: "test-vmss", + ResourceGroup: "my-rg", + } ) func TestReconcileRoleAssignmentsVM(t *testing.T) { @@ -169,7 +174,7 @@ func TestReconcileRoleAssignmentsVMSS(t *testing.T) { s.RoleAssignmentResourceType().Return(azure.VirtualMachineScaleSet) s.ResourceGroup().Return("my-rg") s.Name().Return("test-vmss") - mvmss.Get(gomockinternal.AContext(), "my-rg", "test-vmss").Return(compute.VirtualMachineScaleSet{ + mvmss.Get(gomockinternal.AContext(), &fakeVMSSSpec).Return(compute.VirtualMachineScaleSet{ Identity: &compute.VirtualMachineScaleSetIdentity{ PrincipalID: &fakePrincipalID, }, @@ -187,7 +192,7 @@ func TestReconcileRoleAssignmentsVMSS(t *testing.T) { s.ResourceGroup().Return("my-rg") s.Name().Return("test-vmss") s.HasSystemAssignedIdentity().Return(true) - mvmss.Get(gomockinternal.AContext(), "my-rg", "test-vmss").Return(compute.VirtualMachineScaleSet{}, + mvmss.Get(gomockinternal.AContext(), &fakeVMSSSpec).Return(compute.VirtualMachineScaleSet{}, autorest.NewErrorWithResponse("", "", &http.Response{StatusCode: http.StatusInternalServerError}, "Internal Server Error")) }, }, @@ -202,7 +207,7 @@ func TestReconcileRoleAssignmentsVMSS(t *testing.T) { s.RoleAssignmentResourceType().Return(azure.VirtualMachineScaleSet) s.ResourceGroup().Return("my-rg") s.Name().Return("test-vmss") - mvmss.Get(gomockinternal.AContext(), "my-rg", "test-vmss").Return(compute.VirtualMachineScaleSet{ + mvmss.Get(gomockinternal.AContext(), &fakeVMSSSpec).Return(compute.VirtualMachineScaleSet{ Identity: &compute.VirtualMachineScaleSetIdentity{ PrincipalID: &fakePrincipalID, }, @@ -229,7 +234,7 @@ func TestReconcileRoleAssignmentsVMSS(t *testing.T) { s := &Service{ Scope: scopeMock, Reconciler: asyncMock, - virtualMachineScaleSetClient: vmMock, + 
virtualMachineScaleSetGetter: vmMock, } err := s.Reconcile(context.TODO()) diff --git a/azure/services/scalesets/client.go b/azure/services/scalesets/client.go index b74dbe8fb5b3..c7292a4342e4 100644 --- a/azure/services/scalesets/client.go +++ b/azure/services/scalesets/client.go @@ -18,10 +18,8 @@ package scalesets import ( "context" - "encoding/base64" "encoding/json" "fmt" - "time" "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2021-11-01/compute" "github.com/Azure/go-autorest/autorest" @@ -30,21 +28,21 @@ import ( "k8s.io/utils/pointer" infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" "sigs.k8s.io/cluster-api-provider-azure/azure" - "sigs.k8s.io/cluster-api-provider-azure/azure/converters" "sigs.k8s.io/cluster-api-provider-azure/util/reconciler" "sigs.k8s.io/cluster-api-provider-azure/util/tele" ) // Client wraps go-sdk. type Client interface { + Get(context.Context, azure.ResourceSpecGetter) (interface{}, error) List(context.Context, string) ([]compute.VirtualMachineScaleSet, error) - ListInstances(context.Context, string, string) ([]compute.VirtualMachineScaleSetVM, error) - Get(context.Context, string, string) (compute.VirtualMachineScaleSet, error) - CreateOrUpdateAsync(context.Context, string, string, compute.VirtualMachineScaleSet) (*infrav1.Future, error) - UpdateAsync(context.Context, string, string, compute.VirtualMachineScaleSetUpdate) (*infrav1.Future, error) - GetResultIfDone(ctx context.Context, future *infrav1.Future) (compute.VirtualMachineScaleSet, error) + ListInstances(context.Context, azure.ResourceSpecGetter) ([]compute.VirtualMachineScaleSetVM, error) UpdateInstances(context.Context, string, string, []string) error - DeleteAsync(context.Context, string, string) (*infrav1.Future, error) + + CreateOrUpdateAsync(ctx context.Context, spec azure.ResourceSpecGetter, parameters interface{}) (result interface{}, future azureautorest.FutureAPI, err error) + DeleteAsync(ctx context.Context, spec azure.ResourceSpecGetter) 
(future azureautorest.FutureAPI, err error) + IsDone(ctx context.Context, future azureautorest.FutureAPI) (isDone bool, err error) + Result(ctx context.Context, future azureautorest.FutureAPI, futureType string) (result interface{}, err error) } type ( @@ -94,11 +92,11 @@ func newVirtualMachineScaleSetsClient(subscriptionID string, baseURI string, aut } // ListInstances retrieves information about the model views of a virtual machine scale set. -func (ac *AzureClient) ListInstances(ctx context.Context, resourceGroupName, vmssName string) ([]compute.VirtualMachineScaleSetVM, error) { +func (ac *AzureClient) ListInstances(ctx context.Context, spec azure.ResourceSpecGetter) ([]compute.VirtualMachineScaleSetVM, error) { ctx, _, done := tele.StartSpanWithLogger(ctx, "scalesets.AzureClient.ListInstances") defer done() - itr, err := ac.scalesetvms.ListComplete(ctx, resourceGroupName, vmssName, "", "", "") + itr, err := ac.scalesetvms.ListComplete(ctx, spec.ResourceGroupName(), spec.ResourceName(), "", "", "") if err != nil { return nil, err } @@ -136,132 +134,136 @@ func (ac *AzureClient) List(ctx context.Context, resourceGroupName string) ([]co } // Get retrieves information about the model view of a virtual machine scale set. -func (ac *AzureClient) Get(ctx context.Context, resourceGroupName, vmssName string) (compute.VirtualMachineScaleSet, error) { +func (ac *AzureClient) Get(ctx context.Context, spec azure.ResourceSpecGetter) (interface{}, error) { ctx, _, done := tele.StartSpanWithLogger(ctx, "scalesets.AzureClient.Get") defer done() - return ac.scalesets.Get(ctx, resourceGroupName, vmssName, "") + return ac.scalesets.Get(ctx, spec.ResourceGroupName(), spec.ResourceName(), "") } -// CreateOrUpdateAsync the operation to create or update a virtual machine scale set without waiting for the operation -// to complete. 
-func (ac *AzureClient) CreateOrUpdateAsync(ctx context.Context, resourceGroupName, vmssName string, vmss compute.VirtualMachineScaleSet) (*infrav1.Future, error) { +// CreateOrUpdateAsync creates or updates a virtual machine scale set asynchronously. +// It sends a PUT request to Azure and if accepted without error, the func will return a Future which can be used to track the ongoing +// progress of the operation. +func (ac *AzureClient) CreateOrUpdateAsync(ctx context.Context, spec azure.ResourceSpecGetter, parameters interface{}) (result interface{}, future azureautorest.FutureAPI, err error) { ctx, _, done := tele.StartSpanWithLogger(ctx, "scalesets.AzureClient.CreateOrUpdateAsync") defer done() - future, err := ac.scalesets.CreateOrUpdate(ctx, resourceGroupName, vmssName, vmss) - if err != nil { - return nil, err + scaleset, ok := parameters.(compute.VirtualMachineScaleSet) + if !ok { + return nil, nil, errors.Errorf("%T is not a compute.VirtualMachineScaleSet", parameters) } - ctx, cancel := context.WithTimeout(ctx, reconciler.DefaultAzureCallTimeout) - defer cancel() - - err = future.WaitForCompletionRef(ctx, ac.scalesets.Client) + createFuture, err := ac.scalesets.CreateOrUpdate(ctx, spec.ResourceGroupName(), spec.ResourceName(), scaleset) if err != nil { - // if an error occurs, return the future. - // this means the long-running operation didn't finish in the specified timeout. - return converters.SDKToFuture(&future, infrav1.PutFuture, serviceName, vmssName, resourceGroupName) - } - - // todo: this returns the result VMSS, we should use it - _, err = future.Result(ac.scalesets) - - // if the operation completed, return a nil future. - return nil, err -} - -// UpdateAsync update a VM scale set without waiting for the result of the operation. UpdateAsync sends a PATCH -// request to Azure and if accepted without error, the func will return a Future which can be used to track the ongoing -// progress of the operation. 
-// -// Parameters: -// -// resourceGroupName - the name of the resource group. -// vmssName - the name of the VM scale set to create or update. parameters - the scale set object. -func (ac *AzureClient) UpdateAsync(ctx context.Context, resourceGroupName, vmssName string, parameters compute.VirtualMachineScaleSetUpdate) (*infrav1.Future, error) { - ctx, _, done := tele.StartSpanWithLogger(ctx, "scalesets.AzureClient.UpdateAsync") - defer done() - - future, err := ac.scalesets.Update(ctx, resourceGroupName, vmssName, parameters) - if err != nil { - return nil, errors.Wrapf(err, "failed updating vmss named %q", vmssName) + return nil, nil, err } ctx, cancel := context.WithTimeout(ctx, reconciler.DefaultAzureCallTimeout) defer cancel() - err = future.WaitForCompletionRef(ctx, ac.scalesets.Client) + err = createFuture.WaitForCompletionRef(ctx, ac.scalesets.Client) if err != nil { // if an error occurs, return the future. // this means the long-running operation didn't finish in the specified timeout. - return converters.SDKToFuture(&future, infrav1.PatchFuture, serviceName, vmssName, resourceGroupName) + return nil, &createFuture, err } - // todo: this returns the result VMSS, we should use it - _, err = future.Result(ac.scalesets) - // if the operation completed, return a nil future. - return nil, err + result, err = createFuture.Result(ac.scalesets) + // if the operation completed, return a nil future + return result, nil, err } -// GetResultIfDone fetches the result of a long-running operation future if it is done. 
-func (ac *AzureClient) GetResultIfDone(ctx context.Context, future *infrav1.Future) (compute.VirtualMachineScaleSet, error) { - var genericFuture genericScaleSetFuture - futureData, err := base64.URLEncoding.DecodeString(future.Data) - if err != nil { - return compute.VirtualMachineScaleSet{}, errors.Wrap(err, "failed to base64 decode future data") - } - - switch future.Type { - case infrav1.PatchFuture: - var future compute.VirtualMachineScaleSetsUpdateFuture - if err := json.Unmarshal(futureData, &future); err != nil { - return compute.VirtualMachineScaleSet{}, errors.Wrap(err, "failed to unmarshal future data") - } - - genericFuture = &genericScaleSetFutureImpl{ - FutureAPI: &future, - result: future.Result, - } - case infrav1.PutFuture: - var future compute.VirtualMachineScaleSetsCreateOrUpdateFuture - if err := json.Unmarshal(futureData, &future); err != nil { - return compute.VirtualMachineScaleSet{}, errors.Wrap(err, "failed to unmarshal future data") - } - - genericFuture = &genericScaleSetFutureImpl{ - FutureAPI: &future, - result: future.Result, - } - case infrav1.DeleteFuture: - var future compute.VirtualMachineScaleSetsDeleteFuture - if err := json.Unmarshal(futureData, &future); err != nil { - return compute.VirtualMachineScaleSet{}, errors.Wrap(err, "failed to unmarshal future data") - } - - genericFuture = &deleteResultAdapter{ - VirtualMachineScaleSetsDeleteFuture: future, - } - default: - return compute.VirtualMachineScaleSet{}, errors.Errorf("unknown future type %q", future.Type) - } - - done, err := genericFuture.DoneWithContext(ctx, ac.scalesets) - if err != nil { - return compute.VirtualMachineScaleSet{}, errors.Wrap(err, "failed checking if the operation was complete") - } - - if !done { - return compute.VirtualMachineScaleSet{}, azure.WithTransientError(azure.NewOperationNotDoneError(future), 15*time.Second) - } - - vmss, err := genericFuture.Result(ac.scalesets) - if err != nil { - return vmss, errors.Wrap(err, "failed fetching the result 
of operation for vmss") - } - - return vmss, nil -} +// // UpdateAsync update a VM scale set without waiting for the result of the operation. UpdateAsync sends a PATCH +// // request to Azure and if accepted without error, the func will return a Future which can be used to track the ongoing +// // progress of the operation. +// // +// // Parameters: +// // +// // resourceGroupName - the name of the resource group. +// // vmssName - the name of the VM scale set to create or update. parameters - the scale set object. +// func (ac *AzureClient) UpdateAsync(ctx context.Context, resourceGroupName, vmssName string, parameters compute.VirtualMachineScaleSetUpdate) (*infrav1.Future, error) { +// ctx, _, done := tele.StartSpanWithLogger(ctx, "scalesets.AzureClient.UpdateAsync") +// defer done() + +// future, err := ac.scalesets.Update(ctx, resourceGroupName, vmssName, parameters) +// if err != nil { +// return nil, errors.Wrapf(err, "failed updating vmss named %q", vmssName) +// } + +// ctx, cancel := context.WithTimeout(ctx, reconciler.DefaultAzureCallTimeout) +// defer cancel() + +// err = future.WaitForCompletionRef(ctx, ac.scalesets.Client) +// if err != nil { +// // if an error occurs, return the future. +// // this means the long-running operation didn't finish in the specified timeout. +// return converters.SDKToFuture(&future, infrav1.PatchFuture, serviceName, vmssName, resourceGroupName) +// } +// // todo: this returns the result VMSS, we should use it +// _, err = future.Result(ac.scalesets) + +// // if the operation completed, return a nil future. +// return nil, err +// } + +// // GetResultIfDone fetches the result of a long-running operation future if it is done. 
+// func (ac *AzureClient) GetResultIfDone(ctx context.Context, future *infrav1.Future) (compute.VirtualMachineScaleSet, error) { +// var genericFuture genericScaleSetFuture +// futureData, err := base64.URLEncoding.DecodeString(future.Data) +// if err != nil { +// return compute.VirtualMachineScaleSet{}, errors.Wrap(err, "failed to base64 decode future data") +// } + +// switch future.Type { +// // case infrav1.PatchFuture: +// // var future compute.VirtualMachineScaleSetsUpdateFuture +// // if err := json.Unmarshal(futureData, &future); err != nil { +// // return compute.VirtualMachineScaleSet{}, errors.Wrap(err, "failed to unmarshal future data") +// // } + +// // genericFuture = &genericScaleSetFutureImpl{ +// // FutureAPI: &future, +// // result: future.Result, +// // } +// // case infrav1.PutFuture: +// // var future compute.VirtualMachineScaleSetsCreateOrUpdateFuture +// // if err := json.Unmarshal(futureData, &future); err != nil { +// // return compute.VirtualMachineScaleSet{}, errors.Wrap(err, "failed to unmarshal future data") +// // } + +// // genericFuture = &genericScaleSetFutureImpl{ +// // FutureAPI: &future, +// // result: future.Result, +// // } +// case infrav1.DeleteFuture: +// var future compute.VirtualMachineScaleSetsDeleteFuture +// if err := json.Unmarshal(futureData, &future); err != nil { +// return compute.VirtualMachineScaleSet{}, errors.Wrap(err, "failed to unmarshal future data") +// } + +// genericFuture = &deleteResultAdapter{ +// VirtualMachineScaleSetsDeleteFuture: future, +// } +// default: +// return compute.VirtualMachineScaleSet{}, errors.Errorf("unknown future type %q", future.Type) +// } + +// done, err := genericFuture.DoneWithContext(ctx, ac.scalesets) +// if err != nil { +// return compute.VirtualMachineScaleSet{}, errors.Wrap(err, "failed checking if the operation was complete") +// } + +// if !done { +// return compute.VirtualMachineScaleSet{}, azure.WithTransientError(azure.NewOperationNotDoneError(future), 
15*time.Second) +// } + +// vmss, err := genericFuture.Result(ac.scalesets) +// if err != nil { +// return vmss, errors.Wrap(err, "failed fetching the result of operation for vmss") +// } + +// return vmss, nil +// } // UpdateInstances update instances of a VM scale set. func (ac *AzureClient) UpdateInstances(ctx context.Context, resourceGroupName, vmssName string, instanceIDs []string) error { @@ -289,40 +291,92 @@ func (ac *AzureClient) UpdateInstances(ctx context.Context, resourceGroupName, v // // Parameters: // -// resourceGroupName - the name of the resource group. -// vmssName - the name of the VM scale set to create or update. parameters - the scale set object. -func (ac *AzureClient) DeleteAsync(ctx context.Context, resourceGroupName, vmssName string) (*infrav1.Future, error) { +// spec - The ResourceSpecGetter containing the name and resource group of the virtual machine scale set. +func (ac *AzureClient) DeleteAsync(ctx context.Context, spec azure.ResourceSpecGetter) (future azureautorest.FutureAPI, err error) { ctx, _, done := tele.StartSpanWithLogger(ctx, "scalesets.AzureClient.DeleteAsync") defer done() - future, err := ac.scalesets.Delete(ctx, resourceGroupName, vmssName, pointer.Bool(false)) + deleteFuture, err := ac.scalesets.Delete(ctx, spec.ResourceGroupName(), spec.ResourceName(), pointer.Bool(false)) if err != nil { - return nil, errors.Wrapf(err, "failed deleting vmss named %q", vmssName) + return nil, err } ctx, cancel := context.WithTimeout(ctx, reconciler.DefaultAzureCallTimeout) defer cancel() - err = future.WaitForCompletionRef(ctx, ac.scalesets.Client) + err = deleteFuture.WaitForCompletionRef(ctx, ac.scalesets.Client) if err != nil { // if an error occurs, return the future. // this means the long-running operation didn't finish in the specified timeout. 
- return converters.SDKToFuture(&future, infrav1.DeleteFuture, serviceName, vmssName, resourceGroupName) + return &deleteFuture, err } - _, err = future.Result(ac.scalesets) - + _, err = deleteFuture.Result(ac.scalesets) // if the operation completed, return a nil future. return nil, err } -// Result wraps the delete result so that we can treat it generically. The only thing we care about is if the delete -// was successful. If it wasn't, an error will be returned. -func (da *deleteResultAdapter) Result(client compute.VirtualMachineScaleSetsClient) (compute.VirtualMachineScaleSet, error) { - _, err := da.VirtualMachineScaleSetsDeleteFuture.Result(client) - return compute.VirtualMachineScaleSet{}, err +// IsDone returns true if the long-running operation has completed. +func (ac *AzureClient) IsDone(ctx context.Context, future azureautorest.FutureAPI) (bool, error) { + ctx, _, done := tele.StartSpanWithLogger(ctx, "scalesets.AzureClient.IsDone") + defer done() + + return future.DoneWithContext(ctx, ac.scalesets) } -// Result returns the Result so that we can treat it generically. -func (g *genericScaleSetFutureImpl) Result(client compute.VirtualMachineScaleSetsClient) (compute.VirtualMachineScaleSet, error) { - return g.result(client) +// Result fetches the result of a long-running operation future. +func (ac *AzureClient) Result(ctx context.Context, future azureautorest.FutureAPI, futureType string) (result interface{}, err error) { + _, _, done := tele.StartSpanWithLogger(ctx, "scalesets.AzureClient.Result") + defer done() + + if future == nil { + return nil, errors.Errorf("cannot get result from nil future") + } + + switch futureType { + case infrav1.PatchFuture: + // Marshal and Unmarshal the future to put it into the correct future type so we can access the Result function. + // Unfortunately the FutureAPI can't be casted directly to VirtualMachineScaleSetsUpdateFuture because it is a azureautorest.Future, which doesn't implement the Result function. 
See PR #1686 for discussion on alternatives. + // It was converted back to a generic azureautorest.Future from the CAPZ infrav1.Future type stored in Status: https://github.com/kubernetes-sigs/cluster-api-provider-azure/blob/main/azure/converters/futures.go#L49. + var updateFuture *compute.VirtualMachineScaleSetsUpdateFuture + jsonData, err := future.MarshalJSON() + if err != nil { + return nil, errors.Wrap(err, "failed to marshal future") + } + if err := json.Unmarshal(jsonData, &updateFuture); err != nil { + return nil, errors.Wrap(err, "failed to unmarshal future data") + } + return updateFuture.Result(ac.scalesets) + + case infrav1.PutFuture: + // Marshal and Unmarshal the future to put it into the correct future type so we can access the Result function. + // Unfortunately the FutureAPI can't be casted directly to VirtualMachineScaleSetsCreateOrUpdateFuture because it is a azureautorest.Future, which doesn't implement the Result function. See PR #1686 for discussion on alternatives. + // It was converted back to a generic azureautorest.Future from the CAPZ infrav1.Future type stored in Status: https://github.com/kubernetes-sigs/cluster-api-provider-azure/blob/main/azure/converters/futures.go#L49. + var createFuture *compute.VirtualMachineScaleSetsCreateOrUpdateFuture + jsonData, err := future.MarshalJSON() + if err != nil { + return nil, errors.Wrap(err, "failed to marshal future") + } + if err := json.Unmarshal(jsonData, &createFuture); err != nil { + return nil, errors.Wrap(err, "failed to unmarshal future data") + } + return createFuture.Result(ac.scalesets) + + case infrav1.DeleteFuture: + // Delete does not return a result compute.VirtualMachineScaleSet + return nil, nil + default: + return nil, errors.Errorf("unknown future type %q", futureType) + } } + +// // Result wraps the delete result so that we can treat it generically. The only thing we care about is if the delete +// // was successful. If it wasn't, an error will be returned. 
+// func (da *deleteResultAdapter) Result(client compute.VirtualMachineScaleSetsClient) (compute.VirtualMachineScaleSet, error) { +// _, err := da.VirtualMachineScaleSetsDeleteFuture.Result(client) +// return compute.VirtualMachineScaleSet{}, err +// } + +// // Result returns the Result so that we can treat it generically. +// func (g *genericScaleSetFutureImpl) Result(client compute.VirtualMachineScaleSetsClient) (compute.VirtualMachineScaleSet, error) { +// return g.result(client) +// } diff --git a/azure/services/scalesets/mock_scalesets/client_mock.go b/azure/services/scalesets/mock_scalesets/client_mock.go index 7d7539ceb196..c52b602a44f9 100644 --- a/azure/services/scalesets/mock_scalesets/client_mock.go +++ b/azure/services/scalesets/mock_scalesets/client_mock.go @@ -26,8 +26,9 @@ import ( compute "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2021-11-01/compute" autorest "github.com/Azure/go-autorest/autorest" + azure "github.com/Azure/go-autorest/autorest/azure" gomock "github.com/golang/mock/gomock" - v1beta1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" + azure0 "sigs.k8s.io/cluster-api-provider-azure/azure" ) // MockClient is a mock of Client interface. @@ -54,63 +55,64 @@ func (m *MockClient) EXPECT() *MockClientMockRecorder { } // CreateOrUpdateAsync mocks base method. 
-func (m *MockClient) CreateOrUpdateAsync(arg0 context.Context, arg1, arg2 string, arg3 compute.VirtualMachineScaleSet) (*v1beta1.Future, error) { +func (m *MockClient) CreateOrUpdateAsync(ctx context.Context, spec azure0.ResourceSpecGetter, parameters interface{}) (interface{}, azure.FutureAPI, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateOrUpdateAsync", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(*v1beta1.Future) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret := m.ctrl.Call(m, "CreateOrUpdateAsync", ctx, spec, parameters) + ret0, _ := ret[0].(interface{}) + ret1, _ := ret[1].(azure.FutureAPI) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // CreateOrUpdateAsync indicates an expected call of CreateOrUpdateAsync. -func (mr *MockClientMockRecorder) CreateOrUpdateAsync(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockClientMockRecorder) CreateOrUpdateAsync(ctx, spec, parameters interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOrUpdateAsync", reflect.TypeOf((*MockClient)(nil).CreateOrUpdateAsync), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOrUpdateAsync", reflect.TypeOf((*MockClient)(nil).CreateOrUpdateAsync), ctx, spec, parameters) } // DeleteAsync mocks base method. -func (m *MockClient) DeleteAsync(arg0 context.Context, arg1, arg2 string) (*v1beta1.Future, error) { +func (m *MockClient) DeleteAsync(ctx context.Context, spec azure0.ResourceSpecGetter) (azure.FutureAPI, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteAsync", arg0, arg1, arg2) - ret0, _ := ret[0].(*v1beta1.Future) + ret := m.ctrl.Call(m, "DeleteAsync", ctx, spec) + ret0, _ := ret[0].(azure.FutureAPI) ret1, _ := ret[1].(error) return ret0, ret1 } // DeleteAsync indicates an expected call of DeleteAsync. 
-func (mr *MockClientMockRecorder) DeleteAsync(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockClientMockRecorder) DeleteAsync(ctx, spec interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAsync", reflect.TypeOf((*MockClient)(nil).DeleteAsync), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAsync", reflect.TypeOf((*MockClient)(nil).DeleteAsync), ctx, spec) } // Get mocks base method. -func (m *MockClient) Get(arg0 context.Context, arg1, arg2 string) (compute.VirtualMachineScaleSet, error) { +func (m *MockClient) Get(arg0 context.Context, arg1 azure0.ResourceSpecGetter) (interface{}, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", arg0, arg1, arg2) - ret0, _ := ret[0].(compute.VirtualMachineScaleSet) + ret := m.ctrl.Call(m, "Get", arg0, arg1) + ret0, _ := ret[0].(interface{}) ret1, _ := ret[1].(error) return ret0, ret1 } // Get indicates an expected call of Get. -func (mr *MockClientMockRecorder) Get(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockClientMockRecorder) Get(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockClient)(nil).Get), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockClient)(nil).Get), arg0, arg1) } -// GetResultIfDone mocks base method. -func (m *MockClient) GetResultIfDone(ctx context.Context, future *v1beta1.Future) (compute.VirtualMachineScaleSet, error) { +// IsDone mocks base method. 
+func (m *MockClient) IsDone(ctx context.Context, future azure.FutureAPI) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetResultIfDone", ctx, future) - ret0, _ := ret[0].(compute.VirtualMachineScaleSet) + ret := m.ctrl.Call(m, "IsDone", ctx, future) + ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetResultIfDone indicates an expected call of GetResultIfDone. -func (mr *MockClientMockRecorder) GetResultIfDone(ctx, future interface{}) *gomock.Call { +// IsDone indicates an expected call of IsDone. +func (mr *MockClientMockRecorder) IsDone(ctx, future interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResultIfDone", reflect.TypeOf((*MockClient)(nil).GetResultIfDone), ctx, future) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsDone", reflect.TypeOf((*MockClient)(nil).IsDone), ctx, future) } // List mocks base method. @@ -129,33 +131,33 @@ func (mr *MockClientMockRecorder) List(arg0, arg1 interface{}) *gomock.Call { } // ListInstances mocks base method. -func (m *MockClient) ListInstances(arg0 context.Context, arg1, arg2 string) ([]compute.VirtualMachineScaleSetVM, error) { +func (m *MockClient) ListInstances(arg0 context.Context, arg1 azure0.ResourceSpecGetter) ([]compute.VirtualMachineScaleSetVM, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListInstances", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "ListInstances", arg0, arg1) ret0, _ := ret[0].([]compute.VirtualMachineScaleSetVM) ret1, _ := ret[1].(error) return ret0, ret1 } // ListInstances indicates an expected call of ListInstances. 
-func (mr *MockClientMockRecorder) ListInstances(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockClientMockRecorder) ListInstances(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListInstances", reflect.TypeOf((*MockClient)(nil).ListInstances), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListInstances", reflect.TypeOf((*MockClient)(nil).ListInstances), arg0, arg1) } -// UpdateAsync mocks base method. -func (m *MockClient) UpdateAsync(arg0 context.Context, arg1, arg2 string, arg3 compute.VirtualMachineScaleSetUpdate) (*v1beta1.Future, error) { +// Result mocks base method. +func (m *MockClient) Result(ctx context.Context, future azure.FutureAPI, futureType string) (interface{}, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateAsync", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(*v1beta1.Future) + ret := m.ctrl.Call(m, "Result", ctx, future, futureType) + ret0, _ := ret[0].(interface{}) ret1, _ := ret[1].(error) return ret0, ret1 } -// UpdateAsync indicates an expected call of UpdateAsync. -func (mr *MockClientMockRecorder) UpdateAsync(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +// Result indicates an expected call of Result. +func (mr *MockClientMockRecorder) Result(ctx, future, futureType interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAsync", reflect.TypeOf((*MockClient)(nil).UpdateAsync), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Result", reflect.TypeOf((*MockClient)(nil).Result), ctx, future, futureType) } // UpdateInstances mocks base method. 
diff --git a/azure/services/scalesets/mock_scalesets/scalesets_mock.go b/azure/services/scalesets/mock_scalesets/scalesets_mock.go index 0234d95837e8..36921283f429 100644 --- a/azure/services/scalesets/mock_scalesets/scalesets_mock.go +++ b/azure/services/scalesets/mock_scalesets/scalesets_mock.go @@ -292,35 +292,6 @@ func (mr *MockScaleSetScopeMockRecorder) GetVMImage(arg0 interface{}) *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVMImage", reflect.TypeOf((*MockScaleSetScope)(nil).GetVMImage), arg0) } -// HasBootstrapDataChanges mocks base method. -func (m *MockScaleSetScope) HasBootstrapDataChanges(arg0 context.Context) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HasBootstrapDataChanges", arg0) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// HasBootstrapDataChanges indicates an expected call of HasBootstrapDataChanges. -func (mr *MockScaleSetScopeMockRecorder) HasBootstrapDataChanges(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasBootstrapDataChanges", reflect.TypeOf((*MockScaleSetScope)(nil).HasBootstrapDataChanges), arg0) -} - -// HasReplicasExternallyManaged mocks base method. -func (m *MockScaleSetScope) HasReplicasExternallyManaged(arg0 context.Context) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HasReplicasExternallyManaged", arg0) - ret0, _ := ret[0].(bool) - return ret0 -} - -// HasReplicasExternallyManaged indicates an expected call of HasReplicasExternallyManaged. -func (mr *MockScaleSetScopeMockRecorder) HasReplicasExternallyManaged(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasReplicasExternallyManaged", reflect.TypeOf((*MockScaleSetScope)(nil).HasReplicasExternallyManaged), arg0) -} - // HashKey mocks base method. 
func (m *MockScaleSetScope) HashKey() string { m.ctrl.T.Helper() @@ -349,21 +320,6 @@ func (mr *MockScaleSetScopeMockRecorder) Location() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Location", reflect.TypeOf((*MockScaleSetScope)(nil).Location)) } -// MaxSurge mocks base method. -func (m *MockScaleSetScope) MaxSurge() (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MaxSurge") - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// MaxSurge indicates an expected call of MaxSurge. -func (mr *MockScaleSetScopeMockRecorder) MaxSurge() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaxSurge", reflect.TypeOf((*MockScaleSetScope)(nil).MaxSurge)) -} - // ReconcileReplicas mocks base method. func (m *MockScaleSetScope) ReconcileReplicas(arg0 context.Context, arg1 *azure.VMSS) error { m.ctrl.T.Helper() @@ -405,10 +361,10 @@ func (mr *MockScaleSetScopeMockRecorder) SaveVMImageToStatus(arg0 interface{}) * } // ScaleSetSpec mocks base method. 
-func (m *MockScaleSetScope) ScaleSetSpec() azure.ScaleSetSpec { +func (m *MockScaleSetScope) ScaleSetSpec() azure.ResourceSpecGetter { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ScaleSetSpec") - ret0, _ := ret[0].(azure.ScaleSetSpec) + ret0, _ := ret[0].(azure.ResourceSpecGetter) return ret0 } diff --git a/azure/services/scalesets/scalesets.go b/azure/services/scalesets/scalesets.go index 7aa0e90a22ef..9e62df99b1ae 100644 --- a/azure/services/scalesets/scalesets.go +++ b/azure/services/scalesets/scalesets.go @@ -18,20 +18,17 @@ package scalesets import ( "context" - "encoding/base64" "fmt" - "strconv" - "time" "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2021-11-01/compute" "github.com/pkg/errors" - "k8s.io/utils/pointer" azprovider "sigs.k8s.io/cloud-provider-azure/pkg/provider" infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" "sigs.k8s.io/cluster-api-provider-azure/azure" "sigs.k8s.io/cluster-api-provider-azure/azure/converters" + "sigs.k8s.io/cluster-api-provider-azure/azure/services/async" "sigs.k8s.io/cluster-api-provider-azure/azure/services/resourceskus" - "sigs.k8s.io/cluster-api-provider-azure/util/generators" + "sigs.k8s.io/cluster-api-provider-azure/util/reconciler" "sigs.k8s.io/cluster-api-provider-azure/util/slice" "sigs.k8s.io/cluster-api-provider-azure/util/tele" ) @@ -46,15 +43,12 @@ type ( GetBootstrapData(context.Context) (string, error) GetVMImage(context.Context) (*infrav1.Image, error) SaveVMImageToStatus(*infrav1.Image) - MaxSurge() (int, error) - ScaleSetSpec() azure.ScaleSetSpec + ScaleSetSpec() azure.ResourceSpecGetter VMSSExtensionSpecs() []azure.ResourceSpecGetter SetAnnotation(string, string) SetProviderID(string) SetVMSSState(*azure.VMSS) ReconcileReplicas(context.Context, *azure.VMSS) error - HasReplicasExternallyManaged(context.Context) bool - HasBootstrapDataChanges(context.Context) (bool, error) } // Service provides operations on Azure resources. 
@@ -62,13 +56,16 @@ type ( Scope ScaleSetScope Client resourceSKUCache *resourceskus.Cache + async.Reconciler } ) // New creates a new service. func New(scope ScaleSetScope, skuCache *resourceskus.Cache) *Service { + client := NewClient(scope) return &Service{ - Client: NewClient(scope), + Reconciler: async.New(scope, client, client), + Client: client, Scope: scope, resourceSKUCache: skuCache, } @@ -84,6 +81,10 @@ func (s *Service) Reconcile(ctx context.Context) (retErr error) { ctx, log, done := tele.StartSpanWithLogger(ctx, "scalesets.Service.Reconcile") defer done() + // TODO: do we want this timeout? + ctx, cancel := context.WithTimeout(ctx, reconciler.DefaultAzureServiceReconcileTimeout) + defer cancel() + if err := s.validateSpec(ctx); err != nil { // do as much early validation as possible to limit calls to Azure return err @@ -91,245 +92,91 @@ func (s *Service) Reconcile(ctx context.Context) (retErr error) { var err error - scaleSetSpec := s.Scope.ScaleSetSpec() - - // check if there is an ongoing long running operation - var fetchedVMSS *azure.VMSS - future := s.Scope.GetLongRunningOperationState(s.Scope.ScaleSetSpec().Name, serviceName, infrav1.PutFuture) - if future == nil { - future = s.Scope.GetLongRunningOperationState(s.Scope.ScaleSetSpec().Name, serviceName, infrav1.PatchFuture) - } - - defer func() { - // save the updated state of the VMSS for the MachinePoolScope to use for updating K8s state - if fetchedVMSS == nil { - fetchedVMSS, err = s.getVirtualMachineScaleSet(ctx, scaleSetSpec.Name) - if err != nil && !azure.ResourceNotFound(err) { - log.Error(err, "failed to get vmss in deferred update") - } - } - - if fetchedVMSS != nil { - // Transform the VMSS resource representation to conform to the cloud-provider-azure representation - providerID, err := azprovider.ConvertResourceGroupNameToLower(azure.ProviderIDPrefix + fetchedVMSS.ID) - if err != nil { - log.Error(err, "failed to parse VMSS ID", "ID", fetchedVMSS.ID) - } - 
s.Scope.SetProviderID(providerID) - s.Scope.SetVMSSState(fetchedVMSS) - } - }() + // What's up with this circular logic?? - if future == nil { - fetchedVMSS, err = s.getVirtualMachineScaleSet(ctx, scaleSetSpec.Name) - } else { - fetchedVMSS, err = s.getVirtualMachineScaleSetIfDone(ctx, future) - } + spec := s.Scope.ScaleSetSpec() - switch { - case err != nil && !azure.ResourceNotFound(err): - // There was an error and it was not an HTTP 404 not found. This is either a transient error, like long running operation not done, or an Azure service error. - return errors.Wrapf(err, "failed to get VMSS %s", scaleSetSpec.Name) - case err != nil && azure.ResourceNotFound(err): - // HTTP(404) resource was not found, so we need to create it with a PUT - future, err = s.createVMSS(ctx) - if err != nil { - return errors.Wrap(err, "failed to start creating VMSS") - } - case err == nil: - // HTTP(200) - // VMSS already exists and may have changes; update it with a PATCH - // we do this to avoid overwriting fields in networkProfile modified by cloud-provider - future, err = s.patchVMSSIfNeeded(ctx, fetchedVMSS) - if err != nil { - return errors.Wrap(err, "failed to start updating VMSS") - } + vmssInstances, err := s.Client.ListInstances(ctx, spec) + if err != nil { + result := errors.Wrapf(err, "failed to get existing instances rules") + s.Scope.UpdatePutStatus(infrav1.InboundNATRulesReadyCondition, serviceName, result) + return result } - // Try to get the VMSS to update status if we have created a long running operation. If the VMSS is still in a long - // running operation, getVirtualMachineScaleSetIfDone will return an azure.WithTransientError and requeue. 
- if future != nil { - fetchedVMSS, err = s.getVirtualMachineScaleSetIfDone(ctx, future) - if err != nil { - return errors.Wrapf(err, "failed to get VMSS %s after create or update", scaleSetSpec.Name) - } + scaleSetSpec, ok := spec.(*ScaleSetSpec) + if !ok { + return errors.Errorf("%T is not of type ScaleSetSpec", spec) } - // If we get to here, we have completed any long running VMSS operations (creates / updates) - s.Scope.DeleteLongRunningOperationState(s.Scope.ScaleSetSpec().Name, serviceName, infrav1.PutFuture) - s.Scope.DeleteLongRunningOperationState(s.Scope.ScaleSetSpec().Name, serviceName, infrav1.PatchFuture) - - // This also means that the VMSS extensions were successfully installed - // Note: we want to handle UpdatePutStatus when VMSSExtensions have an error when scalesets become an async service - s.Scope.UpdatePutStatus(infrav1.BootstrapSucceededCondition, serviceName, nil) - - return nil -} - -// Delete deletes a scale set asynchronously. Delete sends a DELETE request to Azure and if accepted without error, -// the VMSS will be considered deleted. The actual delete in Azure may take longer, but should eventually complete. 
-func (s *Service) Delete(ctx context.Context) error { - ctx, log, done := tele.StartSpanWithLogger(ctx, "scalesets.Service.Delete") - defer done() - - var err error + scaleSetSpec.VMSSInstances = vmssInstances - vmssSpec := s.Scope.ScaleSetSpec() - - defer func() { - // save the updated state of the VMSS for the MachinePoolScope to use for updating K8s state - fetchedVMSS, err := s.getVirtualMachineScaleSet(ctx, vmssSpec.Name) - if err != nil && !azure.ResourceNotFound(err) { - log.Error(err, "failed to get vmss in deferred update") - } - - if fetchedVMSS != nil { - s.Scope.SetVMSSState(fetchedVMSS) - } - }() - - // check if there is an ongoing long running operation - future := s.Scope.GetLongRunningOperationState(vmssSpec.Name, serviceName, infrav1.DeleteFuture) - if future != nil { - // if the operation is not complete this will return an error - _, err := s.GetResultIfDone(ctx, future) - if err != nil { - return errors.Wrap(err, "failed to get result from future") - } - - // ScaleSet has been deleted - s.Scope.DeleteLongRunningOperationState(vmssSpec.Name, serviceName, infrav1.DeleteFuture) - // Note: we want to handle UpdateDeleteStatus when VMSSExtensions have an error when scalesets become an async service - s.Scope.UpdateDeleteStatus(infrav1.BootstrapSucceededCondition, serviceName, nil) - - return nil - } - - // no long running delete operation is active, so delete the ScaleSet - log.V(2).Info("deleting VMSS", "scale set", vmssSpec.Name) - future, err = s.Client.DeleteAsync(ctx, s.Scope.ResourceGroup(), vmssSpec.Name) + _, err = s.CreateOrUpdateResource(ctx, scaleSetSpec, serviceName) if err != nil { - if azure.ResourceNotFound(err) { - // already deleted - return nil - } - return errors.Wrapf(err, "failed to delete VMSS %s in resource group %s", vmssSpec.Name, s.Scope.ResourceGroup()) + // TODO: ??? 
+ s.Scope.UpdatePutStatus(infrav1.BootstrapSucceededCondition, serviceName, err) } - s.Scope.SetLongRunningOperationState(future) - if future != nil { - // if future exists, check state of the future - if _, err = s.GetResultIfDone(ctx, future); err != nil { - return errors.Wrap(err, "not done with long running operation, or failed to get result") - } + // TODO: figure out how to + image, err := s.Scope.GetVMImage(ctx) + if err != nil { + return errors.Wrap(err, "failed to get VM image") } - // future is either nil, or the result of the future is complete - s.Scope.DeleteLongRunningOperationState(vmssSpec.Name, serviceName, infrav1.DeleteFuture) - // Note: we want to handle UpdateDeleteStatus when VMSSExtensions have an error when scalesets become an async service - s.Scope.UpdateDeleteStatus(infrav1.BootstrapSucceededCondition, serviceName, nil) - - return nil -} + s.Scope.SaveVMImageToStatus(image) -func (s *Service) createVMSS(ctx context.Context) (*infrav1.Future, error) { - ctx, log, done := tele.StartSpanWithLogger(ctx, "scalesets.Service.createVMSS") - defer done() + fetchedVMSS, err := s.getVirtualMachineScaleSet(ctx, scaleSetSpec) + if err != nil && !azure.ResourceNotFound(err) { + log.Error(err, "failed to get vmss in deferred update") + } - spec := s.Scope.ScaleSetSpec() + if fetchedVMSS != nil { + if err := s.Scope.ReconcileReplicas(ctx, fetchedVMSS); err != nil { + // TODO: move this so we don't return + s.Scope.UpdatePutStatus(infrav1.BootstrapSucceededCondition, serviceName, err) + return errors.Wrap(err, "unable to reconcile replicas") + } - vmss, err := s.buildVMSSFromSpec(ctx, spec) - if err != nil { - return nil, errors.Wrap(err, "failed building VMSS from spec") + // Transform the VMSS resource representation to conform to the cloud-provider-azure representation + providerID, err := azprovider.ConvertResourceGroupNameToLower(azure.ProviderIDPrefix + fetchedVMSS.ID) + if err != nil { + log.Error(err, "failed to parse VMSS ID", "ID", 
fetchedVMSS.ID) + } + s.Scope.SetProviderID(providerID) + s.Scope.SetVMSSState(fetchedVMSS) } - future, err := s.Client.CreateOrUpdateAsync(ctx, s.Scope.ResourceGroup(), spec.Name, vmss) - if err != nil { - return nil, errors.Wrap(err, "cannot create VMSS") - } + s.Scope.UpdatePutStatus(infrav1.BootstrapSucceededCondition, serviceName, err) - log.V(2).Info("starting to create VMSS", "scale set", spec.Name) - s.Scope.SetLongRunningOperationState(future) - return future, err + return err } -func (s *Service) patchVMSSIfNeeded(ctx context.Context, infraVMSS *azure.VMSS) (*infrav1.Future, error) { - ctx, log, done := tele.StartSpanWithLogger(ctx, "scalesets.Service.patchVMSSIfNeeded") +// Delete deletes a scale set asynchronously. Delete sends a DELETE request to Azure and if accepted without error, +// the VMSS will be considered deleted. The actual delete in Azure may take longer, but should eventually complete. +func (s *Service) Delete(ctx context.Context) error { + ctx, log, done := tele.StartSpanWithLogger(ctx, "scalesets.Service.Delete") defer done() - if err := s.Scope.ReconcileReplicas(ctx, infraVMSS); err != nil { - return nil, errors.Wrap(err, "unable to reconcile replicas") - } - - spec := s.Scope.ScaleSetSpec() - - vmss, err := s.buildVMSSFromSpec(ctx, spec) - if err != nil { - return nil, errors.Wrapf(err, "failed to generate scale set update parameters for %s", spec.Name) - } - - patch, err := getVMSSUpdateFromVMSS(vmss) - if err != nil { - return nil, errors.Wrapf(err, "failed to generate vmss patch for %s", spec.Name) - } - - maxSurge, err := s.Scope.MaxSurge() - if err != nil { - return nil, errors.Wrap(err, "failed to calculate maxSurge") - } + // TODO: do we want this timeout? + ctx, cancel := context.WithTimeout(ctx, reconciler.DefaultAzureServiceReconcileTimeout) + defer cancel() - // If the VMSS is managed by an external autoscaler, we should patch the VMSS if customData has changed. 
- shouldPatchCustomData := false - if s.Scope.HasReplicasExternallyManaged(ctx) { - shouldPatchCustomData, err = s.Scope.HasBootstrapDataChanges(ctx) - if err != nil { - return nil, errors.Wrap(err, "unable to calculate custom data hash") - } - if shouldPatchCustomData { - log.V(4).Info("custom data changed") - } else { - log.V(4).Info("custom data unchanged") - } - } + scaleSetSpec := s.Scope.ScaleSetSpec() + err := s.DeleteResource(ctx, scaleSetSpec, serviceName) + // TODO: update the VMSS state in the scope - hasModelChanges := hasModelModifyingDifferences(infraVMSS, vmss) - isFlex := s.Scope.ScaleSetSpec().OrchestrationMode == infrav1.FlexibleOrchestrationMode - updated := true - if !isFlex { - updated = infraVMSS.HasEnoughLatestModelOrNotMixedModel() - } - if maxSurge > 0 && (hasModelChanges || !updated) && !s.Scope.HasReplicasExternallyManaged(ctx) { - // surge capacity with the intention of lowering during instance reconciliation - surge := spec.Capacity + int64(maxSurge) - log.V(4).Info("surging...", "surge", surge, "hasModelChanges", hasModelChanges, "updated", updated) - patch.Sku.Capacity = pointer.Int64(surge) - } + s.Scope.UpdateDeleteStatus(infrav1.BootstrapSucceededCondition, serviceName, err) - // If the VMSS is managed by an external autoscaler, we should patch the VMSS if customData has changed. - // If there are no model changes and no increase in the replica count, do not update the VMSS. 
- // Decreases in replica count is handled by deleting AzureMachinePoolMachine instances in the MachinePoolScope - if *patch.Sku.Capacity <= infraVMSS.Capacity && !hasModelChanges && !shouldPatchCustomData { - log.V(4).Info("nothing to update on vmss", "scale set", spec.Name, "newReplicas", *patch.Sku.Capacity, "oldReplicas", infraVMSS.Capacity, "hasModelChanges", hasModelChanges, "shouldPatchCustomData", shouldPatchCustomData) - return nil, nil + fetchedVMSS, err := s.getVirtualMachineScaleSet(ctx, scaleSetSpec) + if err != nil && !azure.ResourceNotFound(err) { + log.Error(err, "failed to get vmss in deferred update") } - log.V(4).Info("patching vmss", "scale set", spec.Name, "patch", patch) - future, err := s.UpdateAsync(ctx, s.Scope.ResourceGroup(), spec.Name, patch) - if err != nil { - if azure.ResourceConflict(err) { - return nil, azure.WithTransientError(err, 30*time.Second) - } - return nil, errors.Wrap(err, "failed updating VMSS") + if fetchedVMSS != nil { + s.Scope.SetVMSSState(fetchedVMSS) } - s.Scope.SetLongRunningOperationState(future) - log.V(2).Info("successfully started to update vmss", "scale set", spec.Name) - return future, err -} + return err -func hasModelModifyingDifferences(infraVMSS *azure.VMSS, vmss compute.VirtualMachineScaleSet) bool { - other := converters.SDKToVMSS(vmss, []compute.VirtualMachineScaleSetVM{}) - return infraVMSS.HasModelChanges(*other) } func (s *Service) validateSpec(ctx context.Context) error { @@ -337,10 +184,14 @@ func (s *Service) validateSpec(ctx context.Context) error { defer done() spec := s.Scope.ScaleSetSpec() + scaleSetSpec, ok := spec.(*ScaleSetSpec) + if !ok { + return errors.Errorf("%T is not a ScaleSetSpec", spec) + } - sku, err := s.resourceSKUCache.Get(ctx, spec.Size, resourceskus.VirtualMachines) + sku, err := s.resourceSKUCache.Get(ctx, scaleSetSpec.Size, resourceskus.VirtualMachines) if err != nil { - return errors.Wrapf(err, "failed to get SKU %s in compute api", spec.Size) + return errors.Wrapf(err, 
"failed to get SKU %s in compute api", scaleSetSpec.Size) } // Checking if the requested VM size has at least 2 vCPUS @@ -364,12 +215,12 @@ func (s *Service) validateSpec(ctx context.Context) error { } // enable ephemeral OS - if spec.OSDisk.DiffDiskSettings != nil && !sku.HasCapability(resourceskus.EphemeralOSDisk) { - return azure.WithTerminalError(fmt.Errorf("vm size %s does not support ephemeral os. select a different vm size or disable ephemeral os", spec.Size)) + if scaleSetSpec.OSDisk.DiffDiskSettings != nil && !sku.HasCapability(resourceskus.EphemeralOSDisk) { + return azure.WithTerminalError(fmt.Errorf("vm size %s does not support ephemeral os. select a different vm size or disable ephemeral os", scaleSetSpec.Size)) } - if spec.SecurityProfile != nil && !sku.HasCapability(resourceskus.EncryptionAtHost) { - return azure.WithTerminalError(errors.Errorf("encryption at host is not supported for VM type %s", spec.Size)) + if scaleSetSpec.SecurityProfile != nil && !sku.HasCapability(resourceskus.EncryptionAtHost) { + return azure.WithTerminalError(errors.Errorf("encryption at host is not supported for VM type %s", scaleSetSpec.Size)) } // Fetch location and zone to check for their support of ultra disks. @@ -381,10 +232,10 @@ func (s *Service) validateSpec(ctx context.Context) error { for _, zone := range zones { hasLocationCapability := sku.HasLocationCapability(resourceskus.UltraSSDAvailable, location, zone) - err := fmt.Errorf("vm size %s does not support ultra disks in location %s. select a different vm size or disable ultra disks", spec.Size, location) + err := fmt.Errorf("vm size %s does not support ultra disks in location %s. select a different vm size or disable ultra disks", scaleSetSpec.Size, location) // Check support for ultra disks as data disks. 
- for _, disks := range spec.DataDisks { + for _, disks := range scaleSetSpec.DataDisks { if disks.ManagedDisk != nil && disks.ManagedDisk.StorageAccountType == string(compute.StorageAccountTypesUltraSSDLRS) && !hasLocationCapability { @@ -392,8 +243,8 @@ func (s *Service) validateSpec(ctx context.Context) error { } } // Check support for ultra disks as persistent volumes. - if spec.AdditionalCapabilities != nil && spec.AdditionalCapabilities.UltraSSDEnabled != nil { - if *spec.AdditionalCapabilities.UltraSSDEnabled && + if scaleSetSpec.AdditionalCapabilities != nil && scaleSetSpec.AdditionalCapabilities.UltraSSDEnabled != nil { + if *scaleSetSpec.AdditionalCapabilities.UltraSSDEnabled && !hasLocationCapability { return azure.WithTerminalError(err) } @@ -401,11 +252,11 @@ func (s *Service) validateSpec(ctx context.Context) error { } // Validate DiagnosticProfile spec - if spec.DiagnosticsProfile != nil && spec.DiagnosticsProfile.Boot != nil { - if spec.DiagnosticsProfile.Boot.StorageAccountType == infrav1.UserManagedDiagnosticsStorage { - if spec.DiagnosticsProfile.Boot.UserManaged == nil { + if scaleSetSpec.DiagnosticsProfile != nil && scaleSetSpec.DiagnosticsProfile.Boot != nil { + if scaleSetSpec.DiagnosticsProfile.Boot.StorageAccountType == infrav1.UserManagedDiagnosticsStorage { + if scaleSetSpec.DiagnosticsProfile.Boot.UserManaged == nil { return azure.WithTerminalError(fmt.Errorf("userManaged must be specified when storageAccountType is '%s'", infrav1.UserManagedDiagnosticsStorage)) - } else if spec.DiagnosticsProfile.Boot.UserManaged.StorageAccountURI == "" { + } else if scaleSetSpec.DiagnosticsProfile.Boot.UserManaged.StorageAccountURI == "" { return azure.WithTerminalError(fmt.Errorf("storageAccountURI cannot be empty when storageAccountType is '%s'", infrav1.UserManagedDiagnosticsStorage)) } } @@ -416,264 +267,42 @@ func (s *Service) validateSpec(ctx context.Context) error { string(infrav1.UserManagedDiagnosticsStorage), } - if 
!slice.Contains(possibleStorageAccountTypeValues, string(spec.DiagnosticsProfile.Boot.StorageAccountType)) { + if !slice.Contains(possibleStorageAccountTypeValues, string(scaleSetSpec.DiagnosticsProfile.Boot.StorageAccountType)) { return azure.WithTerminalError(fmt.Errorf("invalid storageAccountType: %s. Allowed values are %v", - spec.DiagnosticsProfile.Boot.StorageAccountType, possibleStorageAccountTypeValues)) + scaleSetSpec.DiagnosticsProfile.Boot.StorageAccountType, possibleStorageAccountTypeValues)) } } // Checking if selected availability zones are available selected VM type in location - azsInLocation, err := s.resourceSKUCache.GetZonesWithVMSize(ctx, spec.Size, s.Scope.Location()) + azsInLocation, err := s.resourceSKUCache.GetZonesWithVMSize(ctx, scaleSetSpec.Size, s.Scope.Location()) if err != nil { - return errors.Wrapf(err, "failed to get zones for VM type %s in location %s", spec.Size, s.Scope.Location()) + return errors.Wrapf(err, "failed to get zones for VM type %s in location %s", scaleSetSpec.Size, s.Scope.Location()) } - for _, az := range spec.FailureDomains { + for _, az := range scaleSetSpec.FailureDomains { if !slice.Contains(azsInLocation, az) { - return azure.WithTerminalError(errors.Errorf("availability zone %s is not available for VM type %s in location %s", az, spec.Size, s.Scope.Location())) + return azure.WithTerminalError(errors.Errorf("availability zone %s is not available for VM type %s in location %s", az, scaleSetSpec.Size, s.Scope.Location())) } } return nil } -func (s *Service) buildVMSSFromSpec(ctx context.Context, vmssSpec azure.ScaleSetSpec) (compute.VirtualMachineScaleSet, error) { - ctx, _, done := tele.StartSpanWithLogger(ctx, "scalesets.Service.buildVMSSFromSpec") - defer done() - - sku, err := s.resourceSKUCache.Get(ctx, vmssSpec.Size, resourceskus.VirtualMachines) - if err != nil { - return compute.VirtualMachineScaleSet{}, errors.Wrapf(err, "failed to get find SKU %s in compute api", vmssSpec.Size) - } - - if 
vmssSpec.AcceleratedNetworking == nil { - // set accelerated networking to the capability of the VMSize - accelNet := sku.HasCapability(resourceskus.AcceleratedNetworking) - vmssSpec.AcceleratedNetworking = &accelNet - } - - extensions, err := s.generateExtensions(ctx) - if err != nil { - return compute.VirtualMachineScaleSet{}, err - } - - storageProfile, err := s.generateStorageProfile(ctx, vmssSpec, sku) - if err != nil { - return compute.VirtualMachineScaleSet{}, err - } - - securityProfile, err := getSecurityProfile(vmssSpec, sku) - if err != nil { - return compute.VirtualMachineScaleSet{}, err - } - - priority, evictionPolicy, billingProfile, err := converters.GetSpotVMOptions(vmssSpec.SpotVMOptions, vmssSpec.OSDisk.DiffDiskSettings) - if err != nil { - return compute.VirtualMachineScaleSet{}, errors.Wrapf(err, "failed to get Spot VM options") - } - - diagnosticsProfile := converters.GetDiagnosticsProfile(vmssSpec.DiagnosticsProfile) - - osProfile, err := s.generateOSProfile(ctx, vmssSpec) - if err != nil { - return compute.VirtualMachineScaleSet{}, err - } - - orchestrationMode := converters.GetOrchestrationMode(s.Scope.ScaleSetSpec().OrchestrationMode) - vmss := compute.VirtualMachineScaleSet{ - Location: pointer.String(s.Scope.Location()), - Sku: &compute.Sku{ - Name: pointer.String(vmssSpec.Size), - Tier: pointer.String("Standard"), - Capacity: pointer.Int64(vmssSpec.Capacity), - }, - Zones: &vmssSpec.FailureDomains, - Plan: s.generateImagePlan(ctx), - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ - OrchestrationMode: orchestrationMode, - SinglePlacementGroup: pointer.Bool(false), - VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{ - OsProfile: osProfile, - StorageProfile: storageProfile, - SecurityProfile: securityProfile, - DiagnosticsProfile: diagnosticsProfile, - NetworkProfile: &compute.VirtualMachineScaleSetNetworkProfile{ - NetworkInterfaceConfigurations: 
s.getVirtualMachineScaleSetNetworkConfiguration(vmssSpec), - }, - Priority: priority, - EvictionPolicy: evictionPolicy, - BillingProfile: billingProfile, - ExtensionProfile: &compute.VirtualMachineScaleSetExtensionProfile{ - Extensions: &extensions, - }, - }, - }, - } - - // Set properties specific to VMSS orchestration mode - switch orchestrationMode { - case compute.OrchestrationModeUniform: - vmss.VirtualMachineScaleSetProperties.Overprovision = pointer.Bool(false) - vmss.VirtualMachineScaleSetProperties.UpgradePolicy = &compute.UpgradePolicy{Mode: compute.UpgradeModeManual} - case compute.OrchestrationModeFlexible: - vmss.VirtualMachineScaleSetProperties.VirtualMachineProfile.NetworkProfile.NetworkAPIVersion = - compute.NetworkAPIVersionTwoZeroTwoZeroHyphenMinusOneOneHyphenMinusZeroOne - vmss.VirtualMachineScaleSetProperties.PlatformFaultDomainCount = pointer.Int32(1) - if len(vmssSpec.FailureDomains) > 1 { - vmss.VirtualMachineScaleSetProperties.PlatformFaultDomainCount = pointer.Int32(int32(len(vmssSpec.FailureDomains))) - } - } - - // Assign Identity to VMSS - if vmssSpec.Identity == infrav1.VMIdentitySystemAssigned { - vmss.Identity = &compute.VirtualMachineScaleSetIdentity{ - Type: compute.ResourceIdentityTypeSystemAssigned, - } - } else if vmssSpec.Identity == infrav1.VMIdentityUserAssigned { - userIdentitiesMap, err := converters.UserAssignedIdentitiesToVMSSSDK(vmssSpec.UserAssignedIdentities) - if err != nil { - return vmss, errors.Wrapf(err, "failed to assign identity %q", vmssSpec.Name) - } - vmss.Identity = &compute.VirtualMachineScaleSetIdentity{ - Type: compute.ResourceIdentityTypeUserAssigned, - UserAssignedIdentities: userIdentitiesMap, - } - } - - // Provisionally detect whether there is any Data Disk defined which uses UltraSSDs. - // If that's the case, enable the UltraSSD capability. 
- for _, dataDisk := range vmssSpec.DataDisks { - if dataDisk.ManagedDisk != nil && dataDisk.ManagedDisk.StorageAccountType == string(compute.StorageAccountTypesUltraSSDLRS) { - vmss.VirtualMachineScaleSetProperties.AdditionalCapabilities = &compute.AdditionalCapabilities{ - UltraSSDEnabled: pointer.Bool(true), - } - } - } - - // Set Additional Capabilities if any is present on the spec. - if vmssSpec.AdditionalCapabilities != nil { - // Set UltraSSDEnabled if a specific value is set on the spec for it. - if vmssSpec.AdditionalCapabilities.UltraSSDEnabled != nil { - vmss.AdditionalCapabilities.UltraSSDEnabled = vmssSpec.AdditionalCapabilities.UltraSSDEnabled - } - } - - if vmssSpec.TerminateNotificationTimeout != nil { - vmss.VirtualMachineScaleSetProperties.VirtualMachineProfile.ScheduledEventsProfile = &compute.ScheduledEventsProfile{ - TerminateNotificationProfile: &compute.TerminateNotificationProfile{ - NotBeforeTimeout: pointer.String(fmt.Sprintf("PT%dM", *vmssSpec.TerminateNotificationTimeout)), - Enable: pointer.Bool(true), - }, - } - } - - tags := infrav1.Build(infrav1.BuildParams{ - ClusterName: s.Scope.ClusterName(), - Lifecycle: infrav1.ResourceLifecycleOwned, - Name: pointer.String(vmssSpec.Name), - Role: pointer.String(infrav1.Node), - Additional: s.Scope.AdditionalTags(), - }) - - vmss.Tags = converters.TagsToMap(tags) - return vmss, nil -} - -func (s *Service) getVirtualMachineScaleSetNetworkConfiguration(vmssSpec azure.ScaleSetSpec) *[]compute.VirtualMachineScaleSetNetworkConfiguration { - var backendAddressPools []compute.SubResource - if vmssSpec.PublicLBName != "" { - if vmssSpec.PublicLBAddressPoolName != "" { - backendAddressPools = append(backendAddressPools, - compute.SubResource{ - ID: pointer.String(azure.AddressPoolID(s.Scope.SubscriptionID(), s.Scope.ResourceGroup(), vmssSpec.PublicLBName, vmssSpec.PublicLBAddressPoolName)), - }) - } - } - nicConfigs := []compute.VirtualMachineScaleSetNetworkConfiguration{} - for i, n := range 
vmssSpec.NetworkInterfaces { - nicConfig := compute.VirtualMachineScaleSetNetworkConfiguration{} - nicConfig.VirtualMachineScaleSetNetworkConfigurationProperties = &compute.VirtualMachineScaleSetNetworkConfigurationProperties{} - nicConfig.Name = pointer.String(vmssSpec.Name + "-nic-" + strconv.Itoa(i)) - nicConfig.EnableIPForwarding = pointer.Bool(true) - if n.AcceleratedNetworking != nil { - nicConfig.VirtualMachineScaleSetNetworkConfigurationProperties.EnableAcceleratedNetworking = n.AcceleratedNetworking - } else { - // If AcceleratedNetworking is not specified, use the value from the VMSS spec. - // It will be set to true if the VMSS SKU supports it. - nicConfig.VirtualMachineScaleSetNetworkConfigurationProperties.EnableAcceleratedNetworking = vmssSpec.AcceleratedNetworking - } - - // Create IPConfigs - ipconfigs := []compute.VirtualMachineScaleSetIPConfiguration{} - for j := 0; j < n.PrivateIPConfigs; j++ { - ipconfig := compute.VirtualMachineScaleSetIPConfiguration{ - Name: pointer.String(fmt.Sprintf("ipConfig" + strconv.Itoa(j))), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ - PrivateIPAddressVersion: compute.IPVersionIPv4, - Subnet: &compute.APIEntityReference{ - ID: pointer.String(azure.SubnetID(s.Scope.SubscriptionID(), vmssSpec.VNetResourceGroup, vmssSpec.VNetName, n.SubnetName)), - }, - }, - } - - if j == 0 { - // Always use the first IPConfig as the Primary - ipconfig.Primary = pointer.Bool(true) - } - ipconfigs = append(ipconfigs, ipconfig) - } - if vmssSpec.IPv6Enabled { - ipv6Config := compute.VirtualMachineScaleSetIPConfiguration{ - Name: pointer.String("ipConfigv6"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ - PrivateIPAddressVersion: compute.IPVersionIPv6, - Primary: pointer.Bool(false), - Subnet: &compute.APIEntityReference{ - ID: pointer.String(azure.SubnetID(s.Scope.SubscriptionID(), vmssSpec.VNetResourceGroup, 
vmssSpec.VNetName, n.SubnetName)), - }, - }, - } - ipconfigs = append(ipconfigs, ipv6Config) - } - if i == 0 { - ipconfigs[0].LoadBalancerBackendAddressPools = &backendAddressPools - nicConfig.VirtualMachineScaleSetNetworkConfigurationProperties.Primary = pointer.Bool(true) - } - nicConfig.VirtualMachineScaleSetNetworkConfigurationProperties.IPConfigurations = &ipconfigs - nicConfigs = append(nicConfigs, nicConfig) - } - return &nicConfigs -} - // getVirtualMachineScaleSet provides information about a Virtual Machine Scale Set and its instances. -func (s *Service) getVirtualMachineScaleSet(ctx context.Context, vmssName string) (*azure.VMSS, error) { +func (s *Service) getVirtualMachineScaleSet(ctx context.Context, spec azure.ResourceSpecGetter) (*azure.VMSS, error) { ctx, _, done := tele.StartSpanWithLogger(ctx, "scalesets.Service.getVirtualMachineScaleSet") defer done() - vmss, err := s.Client.Get(ctx, s.Scope.ResourceGroup(), vmssName) + vmssResult, err := s.Client.Get(ctx, spec) if err != nil { return nil, errors.Wrap(err, "failed to get existing vmss") } - - vmssInstances, err := s.Client.ListInstances(ctx, s.Scope.ResourceGroup(), vmssName) - if err != nil { - return nil, errors.Wrap(err, "failed to list instances") + vmss, ok := vmssResult.(compute.VirtualMachineScaleSet) + if !ok { + return nil, errors.Errorf("%T is not a compute.VirtualMachineScaleSet", vmssResult) } - return converters.SDKToVMSS(vmss, vmssInstances), nil -} - -// getVirtualMachineScaleSetIfDone gets a Virtual Machine Scale Set and its instances from Azure if the future is completed. 
-func (s *Service) getVirtualMachineScaleSetIfDone(ctx context.Context, future *infrav1.Future) (*azure.VMSS, error) { - ctx, _, done := tele.StartSpanWithLogger(ctx, "scalesets.Service.getVirtualMachineScaleSetIfDone") - defer done() - - vmss, err := s.GetResultIfDone(ctx, future) - if err != nil { - return nil, errors.Wrap(err, "failed to get result from future") - } - - vmssInstances, err := s.Client.ListInstances(ctx, future.ResourceGroup, future.Name) + vmssInstances, err := s.Client.ListInstances(ctx, spec) if err != nil { return nil, errors.Wrap(err, "failed to list instances") } @@ -681,208 +310,23 @@ func (s *Service) getVirtualMachineScaleSetIfDone(ctx context.Context, future *i return converters.SDKToVMSS(vmss, vmssInstances), nil } -func (s *Service) generateExtensions(ctx context.Context) ([]compute.VirtualMachineScaleSetExtension, error) { - extensions := make([]compute.VirtualMachineScaleSetExtension, len(s.Scope.VMSSExtensionSpecs())) - for i, extensionSpec := range s.Scope.VMSSExtensionSpecs() { - extensionSpec := extensionSpec - parameters, err := extensionSpec.Parameters(ctx, nil) - if err != nil { - return nil, err - } - vmssextension, ok := parameters.(compute.VirtualMachineScaleSetExtension) - if !ok { - return nil, errors.Errorf("%T is not a compute.VirtualMachineScaleSetExtension", parameters) - } - extensions[i] = vmssextension - } - - return extensions, nil -} - -// generateStorageProfile generates a pointer to a compute.VirtualMachineScaleSetStorageProfile which can utilized for VM creation. 
-func (s *Service) generateStorageProfile(ctx context.Context, vmssSpec azure.ScaleSetSpec, sku resourceskus.SKU) (*compute.VirtualMachineScaleSetStorageProfile, error) { - ctx, _, done := tele.StartSpanWithLogger(ctx, "scalesets.Service.generateStorageProfile") - defer done() - - storageProfile := &compute.VirtualMachineScaleSetStorageProfile{ - OsDisk: &compute.VirtualMachineScaleSetOSDisk{ - OsType: compute.OperatingSystemTypes(vmssSpec.OSDisk.OSType), - CreateOption: compute.DiskCreateOptionTypesFromImage, - DiskSizeGB: vmssSpec.OSDisk.DiskSizeGB, - }, - } - - // enable ephemeral OS - if vmssSpec.OSDisk.DiffDiskSettings != nil { - if !sku.HasCapability(resourceskus.EphemeralOSDisk) { - return nil, fmt.Errorf("vm size %s does not support ephemeral os. select a different vm size or disable ephemeral os", vmssSpec.Size) - } - - storageProfile.OsDisk.DiffDiskSettings = &compute.DiffDiskSettings{ - Option: compute.DiffDiskOptions(vmssSpec.OSDisk.DiffDiskSettings.Option), - } - } - - if vmssSpec.OSDisk.ManagedDisk != nil { - storageProfile.OsDisk.ManagedDisk = &compute.VirtualMachineScaleSetManagedDiskParameters{} - if vmssSpec.OSDisk.ManagedDisk.StorageAccountType != "" { - storageProfile.OsDisk.ManagedDisk.StorageAccountType = compute.StorageAccountTypes(vmssSpec.OSDisk.ManagedDisk.StorageAccountType) - } - if vmssSpec.OSDisk.ManagedDisk.DiskEncryptionSet != nil { - storageProfile.OsDisk.ManagedDisk.DiskEncryptionSet = &compute.DiskEncryptionSetParameters{ID: pointer.String(vmssSpec.OSDisk.ManagedDisk.DiskEncryptionSet.ID)} - } - } - - if vmssSpec.OSDisk.CachingType != "" { - storageProfile.OsDisk.Caching = compute.CachingTypes(vmssSpec.OSDisk.CachingType) - } - - dataDisks := make([]compute.VirtualMachineScaleSetDataDisk, len(vmssSpec.DataDisks)) - for i, disk := range vmssSpec.DataDisks { - dataDisks[i] = compute.VirtualMachineScaleSetDataDisk{ - CreateOption: compute.DiskCreateOptionTypesEmpty, - DiskSizeGB: pointer.Int32(disk.DiskSizeGB), - Lun: disk.Lun, - 
Name: pointer.String(azure.GenerateDataDiskName(vmssSpec.Name, disk.NameSuffix)), - } - - if disk.ManagedDisk != nil { - dataDisks[i].ManagedDisk = &compute.VirtualMachineScaleSetManagedDiskParameters{ - StorageAccountType: compute.StorageAccountTypes(disk.ManagedDisk.StorageAccountType), - } - - if disk.ManagedDisk.DiskEncryptionSet != nil { - dataDisks[i].ManagedDisk.DiskEncryptionSet = &compute.DiskEncryptionSetParameters{ID: pointer.String(disk.ManagedDisk.DiskEncryptionSet.ID)} - } - } - } - storageProfile.DataDisks = &dataDisks - - image, err := s.Scope.GetVMImage(ctx) - if err != nil { - return nil, errors.Wrap(err, "failed to get VM image") - } - - s.Scope.SaveVMImageToStatus(image) - - imageRef, err := converters.ImageToSDK(image) - if err != nil { - return nil, err - } - - storageProfile.ImageReference = imageRef - - return storageProfile, nil -} - -func (s *Service) generateOSProfile(ctx context.Context, vmssSpec azure.ScaleSetSpec) (*compute.VirtualMachineScaleSetOSProfile, error) { - sshKey, err := base64.StdEncoding.DecodeString(vmssSpec.SSHKeyData) - if err != nil { - return nil, errors.Wrap(err, "failed to decode ssh public key") - } - bootstrapData, err := s.Scope.GetBootstrapData(ctx) - if err != nil { - return nil, errors.Wrap(err, "failed to retrieve bootstrap data") - } - - osProfile := &compute.VirtualMachineScaleSetOSProfile{ - ComputerNamePrefix: pointer.String(vmssSpec.Name), - AdminUsername: pointer.String(azure.DefaultUserName), - CustomData: pointer.String(bootstrapData), - } - - switch vmssSpec.OSDisk.OSType { - case string(compute.OperatingSystemTypesWindows): - // Cloudbase-init is used to generate a password. - // https://cloudbase-init.readthedocs.io/en/latest/plugins.html#setting-password-main - // - // We generate a random password here in case of failure - // but the password on the VM will NOT be the same as created here. 
- // Access is provided via SSH public key that is set during deployment - // Azure also provides a way to reset user passwords in the case of need. - osProfile.AdminPassword = pointer.String(generators.SudoRandomPassword(123)) - osProfile.WindowsConfiguration = &compute.WindowsConfiguration{ - EnableAutomaticUpdates: pointer.Bool(false), - } - default: - osProfile.LinuxConfiguration = &compute.LinuxConfiguration{ - DisablePasswordAuthentication: pointer.Bool(true), - SSH: &compute.SSHConfiguration{ - PublicKeys: &[]compute.SSHPublicKey{ - { - Path: pointer.String(fmt.Sprintf("/home/%s/.ssh/authorized_keys", azure.DefaultUserName)), - KeyData: pointer.String(string(sshKey)), - }, - }, - }, - } - } - - return osProfile, nil -} - -func (s *Service) generateImagePlan(ctx context.Context) *compute.Plan { - ctx, log, done := tele.StartSpanWithLogger(ctx, "scalesets.Service.generateImagePlan") - defer done() - - image, err := s.Scope.GetVMImage(ctx) - if err != nil { - log.Error(err, "failed to get vm image, disabling Plan") - return nil - } - - if image.SharedGallery != nil && image.SharedGallery.Publisher != nil && image.SharedGallery.SKU != nil && image.SharedGallery.Offer != nil { - return &compute.Plan{ - Publisher: image.SharedGallery.Publisher, - Name: image.SharedGallery.SKU, - Product: image.SharedGallery.Offer, - } - } - - if image.Marketplace == nil || !image.Marketplace.ThirdPartyImage { - return nil - } - - if image.Marketplace.Publisher == "" || image.Marketplace.SKU == "" || image.Marketplace.Offer == "" { - return nil - } - - return &compute.Plan{ - Publisher: pointer.String(image.Marketplace.Publisher), - Name: pointer.String(image.Marketplace.SKU), - Product: pointer.String(image.Marketplace.Offer), - } -} - -func getVMSSUpdateFromVMSS(vmss compute.VirtualMachineScaleSet) (compute.VirtualMachineScaleSetUpdate, error) { - jsonData, err := vmss.MarshalJSON() - if err != nil { - return compute.VirtualMachineScaleSetUpdate{}, err - } - - var update 
compute.VirtualMachineScaleSetUpdate - if err := update.UnmarshalJSON(jsonData); err != nil { - return update, err - } - - // wipe out network profile, so updates won't conflict with Cloud Provider updates - update.VirtualMachineProfile.NetworkProfile = nil - return update, nil -} - -func getSecurityProfile(vmssSpec azure.ScaleSetSpec, sku resourceskus.SKU) (*compute.SecurityProfile, error) { - if vmssSpec.SecurityProfile == nil { - return nil, nil - } - - if !sku.HasCapability(resourceskus.EncryptionAtHost) { - return nil, azure.WithTerminalError(errors.Errorf("encryption at host is not supported for VM type %s", vmssSpec.Size)) - } - - return &compute.SecurityProfile{ - EncryptionAtHost: pointer.Bool(*vmssSpec.SecurityProfile.EncryptionAtHost), - }, nil -} +// getVirtualMachineScaleSetIfDone gets a Virtual Machine Scale Set and its instances from Azure if the future is completed. +// func (s *Service) getVirtualMachineScaleSetIfDone(ctx context.Context, future *infrav1.Future) (*azure.VMSS, error) { +// ctx, _, done := tele.StartSpanWithLogger(ctx, "scalesets.Service.getVirtualMachineScaleSetIfDone") +// defer done() + +// vmss, err := s.GetResultIfDone(ctx, future) +// if err != nil { +// return nil, errors.Wrap(err, "failed to get result from future") +// } + +// vmssInstances, err := s.Client.ListInstances(ctx, future.ResourceGroup, future.Name) +// if err != nil { +// return nil, errors.Wrap(err, "failed to list instances") +// } + +// return converters.SDKToVMSS(vmss, vmssInstances), nil +// } // IsManaged returns always returns true as CAPZ does not support BYO scale set. func (s *Service) IsManaged(ctx context.Context) (bool, error) { diff --git a/azure/services/scalesets/spec.go b/azure/services/scalesets/spec.go new file mode 100644 index 000000000000..4160d0a35a16 --- /dev/null +++ b/azure/services/scalesets/spec.go @@ -0,0 +1,545 @@ +/* +Copyright 2023 The Kubernetes Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package scalesets + +import ( + "context" + "encoding/base64" + "fmt" + "strconv" + + "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2021-11-01/compute" + "github.com/Azure/go-autorest/autorest/to" + "github.com/pkg/errors" + "k8s.io/utils/pointer" + infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" + "sigs.k8s.io/cluster-api-provider-azure/azure" + "sigs.k8s.io/cluster-api-provider-azure/azure/converters" + "sigs.k8s.io/cluster-api-provider-azure/azure/services/resourceskus" + "sigs.k8s.io/cluster-api-provider-azure/util/generators" + "sigs.k8s.io/cluster-api-provider-azure/util/tele" +) + +// ScaleSetSpec defines the specification for a Scale Set. 
type ScaleSetSpec struct {
	Name                         string
	ResourceGroup                string
	Size                         string
	Capacity                     int64
	SSHKeyData                   string
	OSDisk                       infrav1.OSDisk
	DataDisks                    []infrav1.DataDisk
	SubnetName                   string
	VNetName                     string
	VNetResourceGroup            string
	PublicLBName                 string
	PublicLBAddressPoolName      string
	AcceleratedNetworking        *bool
	TerminateNotificationTimeout *int
	Identity                     infrav1.VMIdentity
	UserAssignedIdentities       []infrav1.UserAssignedIdentity
	SecurityProfile              *infrav1.SecurityProfile
	SpotVMOptions                *infrav1.SpotVMOptions
	AdditionalCapabilities       *infrav1.AdditionalCapabilities
	DiagnosticsProfile           *infrav1.Diagnostics
	FailureDomains               []string
	VMExtensions                 []infrav1.VMExtension
	NetworkInterfaces            []infrav1.NetworkInterface
	IPv6Enabled                  bool
	OrchestrationMode            infrav1.OrchestrationModeType
	Location                     string
	SubscriptionID               string
	SKU                          resourceskus.SKU
	VMSSExtensionSpecs           []azure.ResourceSpecGetter
	VMImage                      *infrav1.Image
	BootstrapData                string
	VMSSInstances                []compute.VirtualMachineScaleSetVM
	MaxSurge                     int
	ClusterName                  string
	ShouldPatchCustomData        bool
	AdditionalTags               infrav1.Tags
}

// ResourceName returns the name of the scale set.
func (s *ScaleSetSpec) ResourceName() string {
	return s.Name
}

// ResourceGroupName returns the name of the resource group recorded on the scale set spec.
func (s *ScaleSetSpec) ResourceGroupName() string {
	return s.ResourceGroup
}

// OwnerResourceName always returns an empty string for a scale set spec.
func (s *ScaleSetSpec) OwnerResourceName() string {
	return ""
}

// Parameters returns the parameters for the scale set.
+func (s *ScaleSetSpec) Parameters(ctx context.Context, existing interface{}) (parameters interface{}, err error) { + if existing != nil { + existingScaleSet, ok := existing.(compute.VirtualMachineScaleSet) + if !ok { + return nil, errors.Errorf("%T is not a compute.VirtualMachineScaleSet", existing) + } + + infraVMSS := converters.SDKToVMSS(existingScaleSet, s.VMSSInstances) + + parameters, err := s.Parameters(ctx, nil) + if err != nil { + return nil, errors.Wrapf(err, "failed to generate scale set update parameters for %s", s.Name) + } + + vmss, ok := parameters.(compute.VirtualMachineScaleSet) + if !ok { + return nil, errors.Errorf("%T is not a compute.VirtualMachineScaleSet", existing) + } + + patch, err := getVMSSUpdateFromVMSS(vmss) + if err != nil { + return nil, errors.Wrapf(err, "failed to generate vmss patch for %s", s.Name) + } + + hasModelChanges := hasModelModifyingDifferences(infraVMSS, vmss) + isFlex := s.OrchestrationMode == infrav1.FlexibleOrchestrationMode + updated := true + if !isFlex { + updated = infraVMSS.HasEnoughLatestModelOrNotMixedModel() + } + if s.MaxSurge > 0 && (hasModelChanges || !updated) { + // surge capacity with the intention of lowering during instance reconciliation + surge := s.Capacity + int64(s.MaxSurge) + // log.V(4).Info("surging...", "surge", surge, "hasModelChanges", hasModelChanges, "updated", updated) + patch.Sku.Capacity = pointer.Int64(surge) + } + + // If there are no model changes and no increase in the replica count, do not update the VMSS. 
+ // Decreases in replica count is handled by deleting AzureMachinePoolMachine instances in the MachinePoolScope + if *patch.Sku.Capacity <= infraVMSS.Capacity && !hasModelChanges && !s.ShouldPatchCustomData { + // log.V(4).Info("nothing to update on vmss", "scale set", s.Name, "newReplicas", *patch.Sku.Capacity, "oldReplicas", infraVMSS.Capacity, "hasChanges", hasModelChanges) + + // up to date, nothing to do + return nil, nil + } + + return patch, nil + } + // sku, err := s.resourceSKUCache.Get(ctx, s.Size, resourceskus.VirtualMachines) + // if err != nil { + // return compute.VirtualMachineScaleSet{}, errors.Wrapf(err, "failed to get find SKU %s in compute api", s.Size) + // } + + if s.AcceleratedNetworking == nil { + // set accelerated networking to the capability of the VMSize + accelNet := s.SKU.HasCapability(resourceskus.AcceleratedNetworking) + s.AcceleratedNetworking = &accelNet + } + + extensions, err := s.generateExtensions(ctx) + if err != nil { + return compute.VirtualMachineScaleSet{}, err + } + + storageProfile, err := s.generateStorageProfile(ctx) + if err != nil { + return compute.VirtualMachineScaleSet{}, err + } + + securityProfile, err := s.getSecurityProfile() + if err != nil { + return compute.VirtualMachineScaleSet{}, err + } + + priority, evictionPolicy, billingProfile, err := converters.GetSpotVMOptions(s.SpotVMOptions, s.OSDisk.DiffDiskSettings) + if err != nil { + return compute.VirtualMachineScaleSet{}, errors.Wrapf(err, "failed to get Spot VM options") + } + + diagnosticsProfile := converters.GetDiagnosticsProfile(s.DiagnosticsProfile) + + osProfile, err := s.generateOSProfile(ctx) + if err != nil { + return compute.VirtualMachineScaleSet{}, err + } + + orchestrationMode := converters.GetOrchestrationMode(s.OrchestrationMode) + + vmss := compute.VirtualMachineScaleSet{ + Location: pointer.String(s.Location), + Sku: &compute.Sku{ + Name: pointer.String(s.Size), + Tier: pointer.String("Standard"), + Capacity: pointer.Int64(s.Capacity), + 
}, + Zones: to.StringSlicePtr(s.FailureDomains), + Plan: s.generateImagePlan(ctx), + VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ + OrchestrationMode: orchestrationMode, + SinglePlacementGroup: pointer.Bool(false), + VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{ + OsProfile: osProfile, + StorageProfile: storageProfile, + SecurityProfile: securityProfile, + DiagnosticsProfile: diagnosticsProfile, + NetworkProfile: &compute.VirtualMachineScaleSetNetworkProfile{ + NetworkInterfaceConfigurations: s.getVirtualMachineScaleSetNetworkConfiguration(), + }, + Priority: priority, + EvictionPolicy: evictionPolicy, + BillingProfile: billingProfile, + ExtensionProfile: &compute.VirtualMachineScaleSetExtensionProfile{ + Extensions: &extensions, + }, + }, + }, + } + + // Set properties specific to VMSS orchestration mode + switch orchestrationMode { + case compute.OrchestrationModeUniform: + vmss.VirtualMachineScaleSetProperties.Overprovision = pointer.Bool(false) + vmss.VirtualMachineScaleSetProperties.UpgradePolicy = &compute.UpgradePolicy{Mode: compute.UpgradeModeManual} + case compute.OrchestrationModeFlexible: + vmss.VirtualMachineScaleSetProperties.VirtualMachineProfile.NetworkProfile.NetworkAPIVersion = + compute.NetworkAPIVersionTwoZeroTwoZeroHyphenMinusOneOneHyphenMinusZeroOne + vmss.VirtualMachineScaleSetProperties.PlatformFaultDomainCount = to.Int32Ptr(1) + if len(s.FailureDomains) > 1 { + vmss.VirtualMachineScaleSetProperties.PlatformFaultDomainCount = to.Int32Ptr(int32(len(s.FailureDomains))) + } + } + + // Assign Identity to VMSS + if s.Identity == infrav1.VMIdentitySystemAssigned { + vmss.Identity = &compute.VirtualMachineScaleSetIdentity{ + Type: compute.ResourceIdentityTypeSystemAssigned, + } + } else if s.Identity == infrav1.VMIdentityUserAssigned { + userIdentitiesMap, err := converters.UserAssignedIdentitiesToVMSSSDK(s.UserAssignedIdentities) + if err != nil { + return vmss, errors.Wrapf(err, "failed to assign 
identity %q", s.Name) + } + vmss.Identity = &compute.VirtualMachineScaleSetIdentity{ + Type: compute.ResourceIdentityTypeUserAssigned, + UserAssignedIdentities: userIdentitiesMap, + } + } + + // Provisionally detect whether there is any Data Disk defined which uses UltraSSDs. + // If that's the case, enable the UltraSSD capability. + for _, dataDisk := range s.DataDisks { + if dataDisk.ManagedDisk != nil && dataDisk.ManagedDisk.StorageAccountType == string(compute.StorageAccountTypesUltraSSDLRS) { + vmss.VirtualMachineScaleSetProperties.AdditionalCapabilities = &compute.AdditionalCapabilities{ + UltraSSDEnabled: pointer.Bool(true), + } + } + } + + // Set Additional Capabilities if any is present on the spec. + if s.AdditionalCapabilities != nil { + // Set UltraSSDEnabled if a specific value is set on the spec for it. + if s.AdditionalCapabilities.UltraSSDEnabled != nil { + vmss.AdditionalCapabilities.UltraSSDEnabled = s.AdditionalCapabilities.UltraSSDEnabled + } + } + + if s.TerminateNotificationTimeout != nil { + vmss.VirtualMachineScaleSetProperties.VirtualMachineProfile.ScheduledEventsProfile = &compute.ScheduledEventsProfile{ + TerminateNotificationProfile: &compute.TerminateNotificationProfile{ + NotBeforeTimeout: pointer.String(fmt.Sprintf("PT%dM", *s.TerminateNotificationTimeout)), + Enable: pointer.Bool(true), + }, + } + } + + tags := infrav1.Build(infrav1.BuildParams{ + ClusterName: s.ClusterName, + Lifecycle: infrav1.ResourceLifecycleOwned, + Name: pointer.String(s.Name), + Role: pointer.String(infrav1.Node), + Additional: s.AdditionalTags, + }) + + vmss.Tags = converters.TagsToMap(tags) + return vmss, nil + + return vmss, nil +} + +func hasModelModifyingDifferences(infraVMSS *azure.VMSS, vmss compute.VirtualMachineScaleSet) bool { + other := converters.SDKToVMSS(vmss, []compute.VirtualMachineScaleSetVM{}) + return infraVMSS.HasModelChanges(*other) +} + +func (s *ScaleSetSpec) generateExtensions(ctx context.Context) 
([]compute.VirtualMachineScaleSetExtension, error) { + extensions := make([]compute.VirtualMachineScaleSetExtension, len(s.VMSSExtensionSpecs)) + for i, extensionSpec := range s.VMSSExtensionSpecs { + extensionSpec := extensionSpec + parameters, err := extensionSpec.Parameters(ctx, nil) + if err != nil { + return nil, err + } + vmssextension, ok := parameters.(compute.VirtualMachineScaleSetExtension) + if !ok { + return nil, errors.Errorf("%T is not a compute.VirtualMachineScaleSetExtension", parameters) + } + extensions[i] = vmssextension + } + + return extensions, nil +} + +func (s *ScaleSetSpec) getVirtualMachineScaleSetNetworkConfiguration() *[]compute.VirtualMachineScaleSetNetworkConfiguration { + var backendAddressPools []compute.SubResource + if s.PublicLBName != "" { + if s.PublicLBAddressPoolName != "" { + backendAddressPools = append(backendAddressPools, + compute.SubResource{ + ID: pointer.String(azure.AddressPoolID(s.SubscriptionID, s.ResourceGroup, s.PublicLBName, s.PublicLBAddressPoolName)), + }) + } + } + nicConfigs := []compute.VirtualMachineScaleSetNetworkConfiguration{} + for i, n := range s.NetworkInterfaces { + nicConfig := compute.VirtualMachineScaleSetNetworkConfiguration{} + nicConfig.VirtualMachineScaleSetNetworkConfigurationProperties = &compute.VirtualMachineScaleSetNetworkConfigurationProperties{} + nicConfig.Name = pointer.String(s.Name + "-nic-" + strconv.Itoa(i)) + nicConfig.EnableIPForwarding = pointer.Bool(true) + if n.AcceleratedNetworking != nil { + nicConfig.VirtualMachineScaleSetNetworkConfigurationProperties.EnableAcceleratedNetworking = n.AcceleratedNetworking + } else { + // If AcceleratedNetworking is not specified, use the value from the VMSS spec. + // It will be set to true if the VMSS SKU supports it. 
+ nicConfig.VirtualMachineScaleSetNetworkConfigurationProperties.EnableAcceleratedNetworking = s.AcceleratedNetworking + } + + // Create IPConfigs + ipconfigs := []compute.VirtualMachineScaleSetIPConfiguration{} + for j := 0; j < n.PrivateIPConfigs; j++ { + ipconfig := compute.VirtualMachineScaleSetIPConfiguration{ + Name: pointer.String(fmt.Sprintf("ipConfig" + strconv.Itoa(j))), + VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + PrivateIPAddressVersion: compute.IPVersionIPv4, + Subnet: &compute.APIEntityReference{ + ID: pointer.String(azure.SubnetID(s.SubscriptionID, s.VNetResourceGroup, s.VNetName, n.SubnetName)), + }, + }, + } + + if j == 0 { + // Always use the first IPConfig as the Primary + ipconfig.Primary = pointer.Bool(true) + } + ipconfigs = append(ipconfigs, ipconfig) + } + if s.IPv6Enabled { + ipv6Config := compute.VirtualMachineScaleSetIPConfiguration{ + Name: pointer.String("ipConfigv6"), + VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + PrivateIPAddressVersion: compute.IPVersionIPv6, + Primary: pointer.Bool(false), + Subnet: &compute.APIEntityReference{ + ID: pointer.String(azure.SubnetID(s.SubscriptionID, s.VNetResourceGroup, s.VNetName, n.SubnetName)), + }, + }, + } + ipconfigs = append(ipconfigs, ipv6Config) + } + if i == 0 { + ipconfigs[0].LoadBalancerBackendAddressPools = &backendAddressPools + nicConfig.VirtualMachineScaleSetNetworkConfigurationProperties.Primary = pointer.Bool(true) + } + nicConfig.VirtualMachineScaleSetNetworkConfigurationProperties.IPConfigurations = &ipconfigs + nicConfigs = append(nicConfigs, nicConfig) + } + return &nicConfigs +} + +// generateStorageProfile generates a pointer to a compute.VirtualMachineScaleSetStorageProfile which can utilized for VM creation. 
+func (s *ScaleSetSpec) generateStorageProfile(ctx context.Context) (*compute.VirtualMachineScaleSetStorageProfile, error) { + ctx, _, done := tele.StartSpanWithLogger(ctx, "scalesets.ScaleSetSpec.generateStorageProfile") + defer done() + + storageProfile := &compute.VirtualMachineScaleSetStorageProfile{ + OsDisk: &compute.VirtualMachineScaleSetOSDisk{ + OsType: compute.OperatingSystemTypes(s.OSDisk.OSType), + CreateOption: compute.DiskCreateOptionTypesFromImage, + DiskSizeGB: s.OSDisk.DiskSizeGB, + }, + } + + // enable ephemeral OS + if s.OSDisk.DiffDiskSettings != nil { + if !s.SKU.HasCapability(resourceskus.EphemeralOSDisk) { + return nil, fmt.Errorf("vm size %s does not support ephemeral os. select a different vm size or disable ephemeral os", s.Size) + } + + storageProfile.OsDisk.DiffDiskSettings = &compute.DiffDiskSettings{ + Option: compute.DiffDiskOptions(s.OSDisk.DiffDiskSettings.Option), + } + } + + if s.OSDisk.ManagedDisk != nil { + storageProfile.OsDisk.ManagedDisk = &compute.VirtualMachineScaleSetManagedDiskParameters{} + if s.OSDisk.ManagedDisk.StorageAccountType != "" { + storageProfile.OsDisk.ManagedDisk.StorageAccountType = compute.StorageAccountTypes(s.OSDisk.ManagedDisk.StorageAccountType) + } + if s.OSDisk.ManagedDisk.DiskEncryptionSet != nil { + storageProfile.OsDisk.ManagedDisk.DiskEncryptionSet = &compute.DiskEncryptionSetParameters{ID: pointer.String(s.OSDisk.ManagedDisk.DiskEncryptionSet.ID)} + } + } + + if s.OSDisk.CachingType != "" { + storageProfile.OsDisk.Caching = compute.CachingTypes(s.OSDisk.CachingType) + } + + dataDisks := make([]compute.VirtualMachineScaleSetDataDisk, len(s.DataDisks)) + for i, disk := range s.DataDisks { + dataDisks[i] = compute.VirtualMachineScaleSetDataDisk{ + CreateOption: compute.DiskCreateOptionTypesEmpty, + DiskSizeGB: to.Int32Ptr(disk.DiskSizeGB), + Lun: disk.Lun, + Name: pointer.String(azure.GenerateDataDiskName(s.Name, disk.NameSuffix)), + } + + if disk.ManagedDisk != nil { + dataDisks[i].ManagedDisk = 
&compute.VirtualMachineScaleSetManagedDiskParameters{ + StorageAccountType: compute.StorageAccountTypes(disk.ManagedDisk.StorageAccountType), + } + + if disk.ManagedDisk.DiskEncryptionSet != nil { + dataDisks[i].ManagedDisk.DiskEncryptionSet = &compute.DiskEncryptionSetParameters{ID: pointer.String(disk.ManagedDisk.DiskEncryptionSet.ID)} + } + } + } + storageProfile.DataDisks = &dataDisks + + // TODO: Find a way to save this to status, likely after CreateResource() is called. + // s.Scope.SaveVMImageToStatus(s.VMImage) + + imageRef, err := converters.ImageToSDK(s.VMImage) + if err != nil { + return nil, err + } + + storageProfile.ImageReference = imageRef + + return storageProfile, nil +} + +func (s *ScaleSetSpec) generateOSProfile(ctx context.Context) (*compute.VirtualMachineScaleSetOSProfile, error) { + sshKey, err := base64.StdEncoding.DecodeString(s.SSHKeyData) + if err != nil { + return nil, errors.Wrap(err, "failed to decode ssh public key") + } + + osProfile := &compute.VirtualMachineScaleSetOSProfile{ + ComputerNamePrefix: pointer.String(s.Name), + AdminUsername: pointer.String(azure.DefaultUserName), + CustomData: pointer.String(s.BootstrapData), + } + + switch s.OSDisk.OSType { + case string(compute.OperatingSystemTypesWindows): + // Cloudbase-init is used to generate a password. + // https://cloudbase-init.readthedocs.io/en/latest/plugins.html#setting-password-main + // + // We generate a random password here in case of failure + // but the password on the VM will NOT be the same as created here. + // Access is provided via SSH public key that is set during deployment + // Azure also provides a way to reset user passwords in the case of need. 
+ osProfile.AdminPassword = pointer.String(generators.SudoRandomPassword(123)) + osProfile.WindowsConfiguration = &compute.WindowsConfiguration{ + EnableAutomaticUpdates: pointer.Bool(false), + } + default: + osProfile.LinuxConfiguration = &compute.LinuxConfiguration{ + DisablePasswordAuthentication: pointer.Bool(true), + SSH: &compute.SSHConfiguration{ + PublicKeys: &[]compute.SSHPublicKey{ + { + Path: pointer.String(fmt.Sprintf("/home/%s/.ssh/authorized_keys", azure.DefaultUserName)), + KeyData: pointer.String(string(sshKey)), + }, + }, + }, + } + } + + return osProfile, nil +} + +func (s *ScaleSetSpec) generateImagePlan(ctx context.Context) *compute.Plan { + ctx, log, done := tele.StartSpanWithLogger(ctx, "scalesets.ScaleSetSpec.generateImagePlan") + defer done() + + if s.VMImage == nil { + log.V(2).Info("no vm image found, disabling plan") + return nil + } + + if s.VMImage.SharedGallery != nil && s.VMImage.SharedGallery.Publisher != nil && s.VMImage.SharedGallery.SKU != nil && s.VMImage.SharedGallery.Offer != nil { + return &compute.Plan{ + Publisher: s.VMImage.SharedGallery.Publisher, + Name: s.VMImage.SharedGallery.SKU, + Product: s.VMImage.SharedGallery.Offer, + } + } + + if s.VMImage.Marketplace == nil || !s.VMImage.Marketplace.ThirdPartyImage { + return nil + } + + if s.VMImage.Marketplace.Publisher == "" || s.VMImage.Marketplace.SKU == "" || s.VMImage.Marketplace.Offer == "" { + return nil + } + + return &compute.Plan{ + Publisher: pointer.String(s.VMImage.Marketplace.Publisher), + Name: pointer.String(s.VMImage.Marketplace.SKU), + Product: pointer.String(s.VMImage.Marketplace.Offer), + } +} + +func (s *ScaleSetSpec) getSecurityProfile() (*compute.SecurityProfile, error) { + if s.SecurityProfile == nil { + return nil, nil + } + + if !s.SKU.HasCapability(resourceskus.EncryptionAtHost) { + return nil, azure.WithTerminalError(errors.Errorf("encryption at host is not supported for VM type %s", s.Size)) + } + + return &compute.SecurityProfile{ + 
EncryptionAtHost: pointer.Bool(*s.SecurityProfile.EncryptionAtHost), + }, nil +} + +func getVMSSUpdateFromVMSS(vmss compute.VirtualMachineScaleSet) (compute.VirtualMachineScaleSetUpdate, error) { + jsonData, err := vmss.MarshalJSON() + if err != nil { + return compute.VirtualMachineScaleSetUpdate{}, err + } + + var update compute.VirtualMachineScaleSetUpdate + if err := update.UnmarshalJSON(jsonData); err != nil { + return update, err + } + + // wipe out network profile, so updates won't conflict with Cloud Provider updates + update.VirtualMachineProfile.NetworkProfile = nil + return update, nil +}