Skip to content

Commit

Permalink
Merge pull request #6956 from Bryce-Soghigian/bsoghigian/azure/has-in…
Browse files Browse the repository at this point in the history
…stance-impl

feat: Azure Provider HasInstance implementation
  • Loading branch information
k8s-ci-robot authored Jul 31, 2024
2 parents 392cef8 + 34a26ee commit 1ae8bcc
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 63 deletions.
51 changes: 38 additions & 13 deletions cluster-autoscaler/cloudprovider/azure/azure_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,26 +207,32 @@ func (m *azureCache) regenerate() error {
return nil
}

// fetchAzureResources retrieves and updates the cached Azure resources.
//
// This function performs the following:
// - Fetches and updates the list of Virtual Machine Scale Sets (VMSS) in the specified resource group.
// - Fetches and updates the list of Virtual Machines (VMs) and identifies the node pools they belong to.
// - Maintains a set of VMs pools and VMSS resources which helps the Cluster Autoscaler (CAS) operate on mixed node pools.
//
// Returns an error if any of the Azure API calls fail.
func (m *azureCache) fetchAzureResources() error {
m.mutex.Lock()
defer m.mutex.Unlock()

// fetch all the resources since CAS may be operating on mixed nodepools
// including both VMSS and VMs pools
// NOTE: this lists virtual machine scale sets, not virtual machine
// scale set instances
vmssResult, err := m.fetchScaleSets()
if err == nil {
m.scaleSets = vmssResult
} else {
if err != nil {
return err
}

m.scaleSets = vmssResult
vmResult, vmsPoolSet, err := m.fetchVirtualMachines()
if err == nil {
m.virtualMachines = vmResult
m.vmsPoolSet = vmsPoolSet
} else {
if err != nil {
return err
}
// we fetch both sets of resources since CAS may operate on mixed nodepools
m.virtualMachines = vmResult
m.vmsPoolSet = vmsPoolSet

return nil
}
Expand Down Expand Up @@ -275,8 +281,8 @@ func (m *azureCache) fetchVirtualMachines() (map[string][]compute.VirtualMachine
}

// nodes from vms pool will have tag "aks-managed-agentpool-type" set to "VirtualMachines"
if agnetpoolType := tags[agentpoolTypeTag]; agnetpoolType != nil {
if strings.EqualFold(to.String(agnetpoolType), vmsPoolType) {
if agentpoolType := tags[agentpoolTypeTag]; agentpoolType != nil {
if strings.EqualFold(to.String(agentpoolType), vmsPoolType) {
vmsPoolSet[to.String(vmPoolName)] = struct{}{}
}
}
Expand Down Expand Up @@ -313,7 +319,6 @@ func (m *azureCache) Register(nodeGroup cloudprovider.NodeGroup) bool {
// Node group is already registered and min/max size haven't changed, no action required.
return false
}

m.registeredNodeGroups[i] = nodeGroup
klog.V(4).Infof("Node group %q updated", nodeGroup.Id())
m.invalidateUnownedInstanceCache()
Expand All @@ -322,6 +327,7 @@ func (m *azureCache) Register(nodeGroup cloudprovider.NodeGroup) bool {
}

klog.V(4).Infof("Registering Node Group %q", nodeGroup.Id())

m.registeredNodeGroups = append(m.registeredNodeGroups, nodeGroup)
m.invalidateUnownedInstanceCache()
return true
Expand Down Expand Up @@ -390,6 +396,25 @@ func (m *azureCache) getAutoscalingOptions(ref azureRef) map[string]string {
return m.autoscalingOptions[ref]
}

// HasInstance returns if a given instance exists in the azure cache
func (m *azureCache) HasInstance(providerID string) (bool, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
resourceID, err := convertResourceGroupNameToLower(providerID)
if err != nil {
// Most likely an invalid resource id, we should return an error
// most of these shouldn't make it here do to higher level
// validation in the HasInstance azure.cloudprovider function
return false, err
}

if m.getInstanceFromCache(resourceID) != nil {
return true, nil
}
// couldn't find instance in the cache, assume it's deleted
return false, cloudprovider.ErrNotImplemented
}

// FindForInstance returns node group of the given Instance
func (m *azureCache) FindForInstance(instance *azureRef, vmType string) (cloudprovider.NodeGroup, error) {
vmsPoolSet := m.getVMsPoolSet()
Expand Down
25 changes: 22 additions & 3 deletions cluster-autoscaler/cloudprovider/azure/azure_cloud_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package azure

import (
"fmt"
"io"
"os"
"strings"
Expand Down Expand Up @@ -122,9 +123,27 @@ func (azure *AzureCloudProvider) NodeGroupForNode(node *apiv1.Node) (cloudprovid
return azure.azureManager.GetNodeGroupForInstance(ref)
}

// HasInstance returns whether a given node has a corresponding instance in this cloud provider
func (azure *AzureCloudProvider) HasInstance(*apiv1.Node) (bool, error) {
return true, cloudprovider.ErrNotImplemented
// HasInstance returns whether a given node has a corresponding instance in this cloud provider.
//
// Used to prevent undercount of existing VMs (taint-based overcount of deleted VMs),
// and so should not return false, nil (no instance) if uncertain; return error instead.
// (Think "has instance for sure, else error".) Returning an error causes fallback to taint-based
// determination; use ErrNotImplemented for silent fallback, any other error will be logged.
//
// Expected behavior (should work for VMSS Uniform/Flex, and VMs):
// - exists : return true, nil
// - !exists : return *, ErrNotImplemented (could use custom error for autoscaled nodes)
// - unimplemented case : return *, ErrNotImplemented
// - any other error : return *, error
func (azure *AzureCloudProvider) HasInstance(node *apiv1.Node) (bool, error) {
if node.Spec.ProviderID == "" {
return false, fmt.Errorf("ProviderID for node: %s is empty, skipped", node.Name)
}

if !strings.HasPrefix(node.Spec.ProviderID, "azure://") {
return false, fmt.Errorf("invalid azure ProviderID prefix for node: %s, skipped", node.Name)
}
return azure.azureManager.azureCache.HasInstance(node.Spec.ProviderID)
}

// Pricing returns pricing model for this cloud provider or error if not available.
Expand Down
224 changes: 178 additions & 46 deletions cluster-autoscaler/cloudprovider/azure/azure_cloud_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@ limitations under the License.
package azure

import (
"fmt"
"testing"

"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute"
"github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2017-05-10/resources"
"github.com/Azure/go-autorest/autorest/to"

apiv1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"k8s.io/autoscaler/cluster-autoscaler/cloudprovider"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient"
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient/mockvmssclient"
Expand Down Expand Up @@ -131,6 +134,126 @@ func TestNodeGroups(t *testing.T) {
assert.Equal(t, len(provider.NodeGroups()), 2)
}

func TestHasInstance(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

provider := newTestProvider(t)
mockVMSSClient := mockvmssclient.NewMockInterface(ctrl)
mockVMClient := mockvmclient.NewMockInterface(ctrl)
mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl)
provider.azureManager.azClient.virtualMachinesClient = mockVMClient
provider.azureManager.azClient.virtualMachineScaleSetsClient = mockVMSSClient
provider.azureManager.azClient.virtualMachineScaleSetVMsClient = mockVMSSVMClient

// Simulate node groups and instances
expectedScaleSets := newTestVMSSList(3, "test-asg", "eastus", compute.Uniform)
expectedVMsPoolVMs := newTestVMsPoolVMList(3)
expectedVMSSVMs := newTestVMSSVMList(3)

mockVMSSClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup).Return(expectedScaleSets, nil).AnyTimes()
mockVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup).Return(expectedVMsPoolVMs, nil).AnyTimes()
mockVMSSVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup, "test-asg", gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes()

// Register node groups
assert.Equal(t, len(provider.NodeGroups()), 0)
registered := provider.azureManager.RegisterNodeGroup(
newTestScaleSet(provider.azureManager, "test-asg"),
)
provider.azureManager.explicitlyConfigured["test-asg"] = true
assert.True(t, registered)

registered = provider.azureManager.RegisterNodeGroup(
newTestVMsPool(provider.azureManager, "test-vms-pool"),
)
provider.azureManager.explicitlyConfigured["test-vms-pool"] = true
assert.True(t, registered)
assert.Equal(t, len(provider.NodeGroups()), 2)

// Refresh cache
provider.azureManager.forceRefresh()

// Test HasInstance for a node from the VMSS pool
node := newApiNode(compute.Uniform, 0)
hasInstance, err := provider.azureManager.azureCache.HasInstance(node.Spec.ProviderID)
assert.True(t, hasInstance)
assert.NoError(t, err)

// Test HasInstance for a node from the VMs pool
vmsPoolNode := newVMsNode(0)
hasInstance, err = provider.azureManager.azureCache.HasInstance(vmsPoolNode.Spec.ProviderID)
assert.True(t, hasInstance)
assert.NoError(t, err)
}

func TestUnownedInstancesFallbackToDeletionTaint(t *testing.T) {
// VMSS Instances that belong to a VMSS on the cluster but do not belong to a registered ASG
// should return err unimplemented for HasInstance
ctrl := gomock.NewController(t)
defer ctrl.Finish()
provider := newTestProvider(t)
mockVMSSClient := mockvmssclient.NewMockInterface(ctrl)
mockVMClient := mockvmclient.NewMockInterface(ctrl)
mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl)
provider.azureManager.azClient.virtualMachinesClient = mockVMClient
provider.azureManager.azClient.virtualMachineScaleSetsClient = mockVMSSClient
provider.azureManager.azClient.virtualMachineScaleSetVMsClient = mockVMSSVMClient

// // Simulate VMSS instances
unregisteredVMSSInstance := &apiv1.Node{
ObjectMeta: metav1.ObjectMeta{
Name: "unregistered-vmss-node",
},
Spec: apiv1.NodeSpec{
ProviderID: "azure:///subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/unregistered-vmss-instance-id/virtualMachines/0",
},
}
// Mock responses to simulate that the instance belongs to a VMSS not in any registered ASG
expectedVMSSVMs := newTestVMSSVMList(1)
mockVMSSVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup, "unregistered-vmss-instance-id", gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes()

// Call HasInstance and check the result
hasInstance, err := provider.azureManager.azureCache.HasInstance(unregisteredVMSSInstance.Spec.ProviderID)
assert.False(t, hasInstance)
assert.Equal(t, cloudprovider.ErrNotImplemented, err)
}

func TestHasInstanceProviderIDErrorValidation(t *testing.T) {
provider := newTestProvider(t)
// Test case: Node with an empty ProviderID
nodeWithoutValidProviderID := &apiv1.Node{
ObjectMeta: metav1.ObjectMeta{
Name: "test-node",
},
Spec: apiv1.NodeSpec{
ProviderID: "",
},
}
_, err := provider.HasInstance(nodeWithoutValidProviderID)
assert.Equal(t, "ProviderID for node: test-node is empty, skipped", err.Error())

// Test cases: Nodes with invalid ProviderID prefixes
invalidProviderIDs := []string{
"aazure://",
"kubemark://",
"kwok://",
"incorrect!",
}

for _, providerID := range invalidProviderIDs {
invalidProviderIDNode := &apiv1.Node{
ObjectMeta: metav1.ObjectMeta{
Name: "test-node",
},
Spec: apiv1.NodeSpec{
ProviderID: providerID,
},
}
_, err := provider.HasInstance(invalidProviderIDNode)
assert.Equal(t, "invalid azure ProviderID prefix for node: test-node, skipped", err.Error())
}
}

func TestMixedNodeGroups(t *testing.T) {
ctrl := gomock.NewController(t)
provider := newTestProvider(t)
Expand Down Expand Up @@ -188,57 +311,66 @@ func TestMixedNodeGroups(t *testing.T) {
func TestNodeGroupForNode(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
orchestrationModes := [2]compute.OrchestrationMode{compute.Uniform, compute.Flexible}
orchestrationModes := []compute.OrchestrationMode{compute.Uniform, compute.Flexible}

expectedVMSSVMs := newTestVMSSVMList(3)
expectedVMs := newTestVMList(3)

for _, orchMode := range orchestrationModes {
expectedScaleSets := newTestVMSSList(3, "test-asg", "eastus", orchMode)
provider := newTestProvider(t)
mockVMSSClient := mockvmssclient.NewMockInterface(ctrl)
mockVMSSClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup).Return(expectedScaleSets, nil)
provider.azureManager.azClient.virtualMachineScaleSetsClient = mockVMSSClient
mockVMClient := mockvmclient.NewMockInterface(ctrl)
provider.azureManager.azClient.virtualMachinesClient = mockVMClient
mockVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup).Return(expectedVMs, nil).AnyTimes()

if orchMode == compute.Uniform {
mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl)
mockVMSSVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup, "test-asg", gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes()
provider.azureManager.azClient.virtualMachineScaleSetVMsClient = mockVMSSVMClient
} else {

provider.azureManager.config.EnableVmssFlex = true
mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), "test-asg").Return(expectedVMs, nil).AnyTimes()

}

registered := provider.azureManager.RegisterNodeGroup(
newTestScaleSet(provider.azureManager, testASG))
provider.azureManager.explicitlyConfigured[testASG] = true
assert.True(t, registered)
assert.Equal(t, len(provider.NodeGroups()), 1)

node := newApiNode(orchMode, 0)
// refresh cache
provider.azureManager.forceRefresh()
group, err := provider.NodeGroupForNode(node)
assert.NoError(t, err)
assert.NotNil(t, group, "Group should not be nil")
assert.Equal(t, group.Id(), testASG)
assert.Equal(t, group.MinSize(), 1)
assert.Equal(t, group.MaxSize(), 5)

// test node in cluster that is not in a group managed by cluster autoscaler
nodeNotInGroup := &apiv1.Node{
Spec: apiv1.NodeSpec{
ProviderID: azurePrefix + "/subscriptions/subscripion/resourceGroups/test-resource-group/providers/Microsoft.Compute/virtualMachines/test-instance-id-not-in-group",
},
}
group, err = provider.NodeGroupForNode(nodeNotInGroup)
assert.NoError(t, err)
assert.Nil(t, group)
t.Run(fmt.Sprintf("OrchestrationMode_%v", orchMode), func(t *testing.T) {
expectedScaleSets := newTestVMSSList(3, "test-asg", "eastus", orchMode)
provider := newTestProvider(t)
mockVMSSClient := mockvmssclient.NewMockInterface(ctrl)
mockVMSSClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup).Return(expectedScaleSets, nil)
provider.azureManager.azClient.virtualMachineScaleSetsClient = mockVMSSClient
mockVMClient := mockvmclient.NewMockInterface(ctrl)
provider.azureManager.azClient.virtualMachinesClient = mockVMClient
mockVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup).Return(expectedVMs, nil).AnyTimes()

if orchMode == compute.Uniform {
mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl)
mockVMSSVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup, "test-asg", gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes()
provider.azureManager.azClient.virtualMachineScaleSetVMsClient = mockVMSSVMClient
} else {
provider.azureManager.config.EnableVmssFlex = true
mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), "test-asg").Return(expectedVMs, nil).AnyTimes()
}

registered := provider.azureManager.RegisterNodeGroup(
newTestScaleSet(provider.azureManager, "test-asg"))
provider.azureManager.explicitlyConfigured["test-asg"] = true
assert.True(t, registered)
assert.Equal(t, len(provider.NodeGroups()), 1)

node := newApiNode(orchMode, 0)
// refresh cache
provider.azureManager.forceRefresh()
group, err := provider.NodeGroupForNode(node)
assert.NoError(t, err)
assert.NotNil(t, group, "Group should not be nil")
assert.Equal(t, group.Id(), "test-asg")
assert.Equal(t, group.MinSize(), 1)
assert.Equal(t, group.MaxSize(), 5)

hasInstance, err := provider.HasInstance(node)
assert.True(t, hasInstance)
assert.NoError(t, err)

// test node in cluster that is not in a group managed by cluster autoscaler
nodeNotInGroup := &apiv1.Node{
Spec: apiv1.NodeSpec{
ProviderID: "azure:///subscriptions/subscription/resourceGroups/test-resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/test/virtualMachines/test-instance-id-not-in-group",
},
}
group, err = provider.NodeGroupForNode(nodeNotInGroup)
assert.NoError(t, err)
assert.Nil(t, group)

hasInstance, err = provider.HasInstance(nodeNotInGroup)
assert.False(t, hasInstance)
assert.Error(t, err)
assert.Equal(t, err, cloudprovider.ErrNotImplemented)
})
}
}

Expand Down
1 change: 0 additions & 1 deletion cluster-autoscaler/cloudprovider/azure/azure_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ func (m *AzureManager) buildNodeGroupFromSpec(spec string) (cloudprovider.NodeGr
if err != nil {
return nil, fmt.Errorf("failed to parse node group spec: %v", err)
}

vmsPoolSet := m.azureCache.getVMsPoolSet()
if _, ok := vmsPoolSet[s.Name]; ok {
return NewVMsPool(s, m), nil
Expand Down

0 comments on commit 1ae8bcc

Please sign in to comment.