diff --git a/azure/interfaces.go b/azure/interfaces.go index c089e96cd36..d6aa632b733 100644 --- a/azure/interfaces.go +++ b/azure/interfaces.go @@ -62,6 +62,7 @@ type NetworkDescriber interface { SetSubnet(infrav1.SubnetSpec) IsIPv6Enabled() bool ControlPlaneRouteTable() infrav1.RouteTable + APIServerLB() *infrav1.LoadBalancerSpec APIServerLBName() string APIServerLBPoolName(string) string IsAPIServerPrivate() bool diff --git a/azure/mock_azure/azure_mock.go b/azure/mock_azure/azure_mock.go index faa711f250b..ba7053081b1 100644 --- a/azure/mock_azure/azure_mock.go +++ b/azure/mock_azure/azure_mock.go @@ -305,6 +305,20 @@ func (m *MockNetworkDescriber) EXPECT() *MockNetworkDescriberMockRecorder { return m.recorder } +// APIServerLB mocks base method. +func (m *MockNetworkDescriber) APIServerLB() *v1beta1.LoadBalancerSpec { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "APIServerLB") + ret0, _ := ret[0].(*v1beta1.LoadBalancerSpec) + return ret0 +} + +// APIServerLB indicates an expected call of APIServerLB. +func (mr *MockNetworkDescriberMockRecorder) APIServerLB() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "APIServerLB", reflect.TypeOf((*MockNetworkDescriber)(nil).APIServerLB)) +} + // APIServerLBName mocks base method. func (m *MockNetworkDescriber) APIServerLBName() string { m.ctrl.T.Helper() @@ -866,6 +880,20 @@ func (m *MockClusterScoper) EXPECT() *MockClusterScoperMockRecorder { return m.recorder } +// APIServerLB mocks base method. +func (m *MockClusterScoper) APIServerLB() *v1beta1.LoadBalancerSpec { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "APIServerLB") + ret0, _ := ret[0].(*v1beta1.LoadBalancerSpec) + return ret0 +} + +// APIServerLB indicates an expected call of APIServerLB. +func (mr *MockClusterScoperMockRecorder) APIServerLB() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "APIServerLB", reflect.TypeOf((*MockClusterScoper)(nil).APIServerLB)) +} + // APIServerLBName mocks base method. func (m *MockClusterScoper) APIServerLBName() string { m.ctrl.T.Helper() diff --git a/azure/scope/machine.go b/azure/scope/machine.go index b5bc0b91ecb..1309c20443f 100644 --- a/azure/scope/machine.go +++ b/azure/scope/machine.go @@ -40,6 +40,7 @@ import ( "sigs.k8s.io/cluster-api-provider-azure/azure" "sigs.k8s.io/cluster-api-provider-azure/azure/services/availabilitysets" "sigs.k8s.io/cluster-api-provider-azure/azure/services/disks" + "sigs.k8s.io/cluster-api-provider-azure/azure/services/inboundnatrules" "sigs.k8s.io/cluster-api-provider-azure/azure/services/resourceskus" "sigs.k8s.io/cluster-api-provider-azure/azure/services/virtualmachines" "sigs.k8s.io/cluster-api-provider-azure/util/futures" @@ -193,16 +194,25 @@ func (m *MachineScope) PublicIPSpecs() []azure.PublicIPSpec { } // InboundNatSpecs returns the inbound NAT specs. -func (m *MachineScope) InboundNatSpecs() []azure.InboundNatSpec { +func (m *MachineScope) InboundNatSpecs(portsInUse map[int32]struct{}) []azure.ResourceSpecGetter { + // The existing inbound NAT rules are needed in order to find an available SSH port for each new inbound NAT rule. if m.Role() == infrav1.ControlPlane { - return []azure.InboundNatSpec{ - { - Name: m.Name(), - LoadBalancerName: m.APIServerLBName(), - }, + spec := &inboundnatrules.InboundNatSpec{ + Name: m.Name(), + ResourceGroup: m.ResourceGroup(), + LoadBalancerName: m.APIServerLBName(), + FrontendIPConfigurationID: nil, + PortsInUse: portsInUse, } + if frontEndIPs := m.APIServerLB().FrontendIPs; len(frontEndIPs) > 0 { + ipConfig := frontEndIPs[0].Name + id := azure.FrontendIPConfigID(m.SubscriptionID(), m.ResourceGroup(), m.APIServerLBName(), ipConfig) + spec.FrontendIPConfigurationID = to.StringPtr(id) + } + + return []azure.ResourceSpecGetter{spec} } - return []azure.InboundNatSpec{} + return []azure.ResourceSpecGetter{} } // NICSpecs returns the network interface specs. diff --git a/azure/scope/machine_test.go b/azure/scope/machine_test.go index d2f4c232eb1..5400020c616 100644 --- a/azure/scope/machine_test.go +++ b/azure/scope/machine_test.go @@ -18,6 +18,7 @@ package scope import ( "context" + "fmt" "reflect" "testing" @@ -32,6 +33,7 @@ import ( infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" "sigs.k8s.io/cluster-api-provider-azure/azure" "sigs.k8s.io/cluster-api-provider-azure/azure/services/disks" + "sigs.k8s.io/cluster-api-provider-azure/azure/services/inboundnatrules" ) func TestMachineScope_Name(t *testing.T) { @@ -283,7 +285,7 @@ func TestMachineScope_InboundNatSpecs(t *testing.T) { tests := []struct { name string machineScope MachineScope - want []azure.InboundNatSpec + want []azure.ResourceSpecGetter }{ { name: "returns empty when infra is not control plane", @@ -295,7 +297,7 @@ func TestMachineScope_InboundNatSpecs(t *testing.T) { }, }, }, - want: []azure.InboundNatSpec{}, + want: []azure.ResourceSpecGetter{}, }, { name: "returns InboundNatSpec when infra is control plane", @@ -313,29 +315,58 @@ func TestMachineScope_InboundNatSpecs(t *testing.T) { }, }, ClusterScoper: &ClusterScope{ + AzureClients: AzureClients{ + EnvironmentSettings: auth.EnvironmentSettings{ + Values: map[string]string{ + auth.SubscriptionID: "123", + }, + }, + }, AzureCluster: &infrav1.AzureCluster{ Spec: infrav1.AzureClusterSpec{ + ResourceGroup: "my-rg", + SubscriptionID: "123", NetworkSpec: infrav1.NetworkSpec{ APIServerLB: infrav1.LoadBalancerSpec{ Name: "foo-loadbalancer", + FrontendIPs: []infrav1.FrontendIP{ + { + Name: "foo-frontend-ip", + }, + }, }, }, }, }, }, }, - want: []azure.InboundNatSpec{ - { - Name: "machine-name", - LoadBalancerName: "foo-loadbalancer", + want: []azure.ResourceSpecGetter{ + &inboundnatrules.InboundNatSpec{ + Name: "machine-name", + LoadBalancerName: "foo-loadbalancer", + ResourceGroup: "my-rg", + FrontendIPConfigurationID: to.StringPtr(azure.FrontendIPConfigID("123", "my-rg", "foo-loadbalancer", "foo-frontend-ip")), + PortsInUse: make(map[int32]struct{}), }, }, }, } for _, tt := range tests { + tt := tt t.Run(tt.name, func(t *testing.T) { - if got := tt.machineScope.InboundNatSpecs(); !reflect.DeepEqual(got, tt.want) { - t.Errorf("InboundNatSpecs() = %v, want %v", got, tt.want) + t.Parallel() + if got := tt.machineScope.InboundNatSpecs(make(map[int32]struct{})); !reflect.DeepEqual(got, tt.want) { + gotArray := "[ " + for _, spec := range got { + gotArray += fmt.Sprintf("%+v ", spec) + } + gotArray += "]" + wantArray := "[ " + for _, spec := range tt.want { + wantArray += fmt.Sprintf("%+v ", spec) + } + wantArray += "]" + t.Errorf("InboundNatSpecs([]converters.ExistingInboundNatSpec{}) = %s, want %s", gotArray, wantArray) } }) } diff --git a/azure/scope/managedcontrolplane.go b/azure/scope/managedcontrolplane.go index 396212259e7..900964b8d6d 100644 --- a/azure/scope/managedcontrolplane.go +++ b/azure/scope/managedcontrolplane.go @@ -316,6 +316,11 @@ func (s *ManagedControlPlaneScope) IsVnetManaged() bool { return true } +// APIServerLBName returns the API Server LB spec. +func (s *ManagedControlPlaneScope) APIServerLB() *infrav1.LoadBalancerSpec { + return nil // does not apply for AKS +} + // APIServerLBName returns the API Server LB name. func (s *ManagedControlPlaneScope) APIServerLBName() string { return "" // does not apply for AKS diff --git a/azure/services/bastionhosts/mocks_bastionhosts/bastionhosts_mock.go b/azure/services/bastionhosts/mocks_bastionhosts/bastionhosts_mock.go index 0e88cd26deb..7c5915acd41 100644 --- a/azure/services/bastionhosts/mocks_bastionhosts/bastionhosts_mock.go +++ b/azure/services/bastionhosts/mocks_bastionhosts/bastionhosts_mock.go @@ -52,6 +52,20 @@ func (m *MockBastionScope) EXPECT() *MockBastionScopeMockRecorder { return m.recorder } +// APIServerLB mocks base method. +func (m *MockBastionScope) APIServerLB() *v1beta1.LoadBalancerSpec { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "APIServerLB") + ret0, _ := ret[0].(*v1beta1.LoadBalancerSpec) + return ret0 +} + +// APIServerLB indicates an expected call of APIServerLB. +func (mr *MockBastionScopeMockRecorder) APIServerLB() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "APIServerLB", reflect.TypeOf((*MockBastionScope)(nil).APIServerLB)) +} + // APIServerLBName mocks base method. func (m *MockBastionScope) APIServerLBName() string { m.ctrl.T.Helper() diff --git a/azure/services/inboundnatrules/client.go b/azure/services/inboundnatrules/client.go index d166f1fb8fb..36055f32f17 100644 --- a/azure/services/inboundnatrules/client.go +++ b/azure/services/inboundnatrules/client.go @@ -18,19 +18,28 @@ package inboundnatrules import ( "context" + "encoding/json" + "fmt" "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2021-02-01/network" "github.com/Azure/go-autorest/autorest" + azureautorest "github.com/Azure/go-autorest/autorest/azure" + "github.com/pkg/errors" + infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" "sigs.k8s.io/cluster-api-provider-azure/azure" + "sigs.k8s.io/cluster-api-provider-azure/util/reconciler" "sigs.k8s.io/cluster-api-provider-azure/util/tele" ) // client wraps go-sdk. type client interface { - Get(context.Context, string, string, string) (network.InboundNatRule, error) - CreateOrUpdate(context.Context, string, string, string, network.InboundNatRule) error - Delete(context.Context, string, string, string) error + List(context.Context, string, string) (result []network.InboundNatRule, err error) + Get(context.Context, azure.ResourceSpecGetter) (result interface{}, err error) + CreateOrUpdateAsync(context.Context, azure.ResourceSpecGetter, interface{}) (result interface{}, future azureautorest.FutureAPI, err error) + DeleteAsync(context.Context, azure.ResourceSpecGetter) (future azureautorest.FutureAPI, err error) + IsDone(context.Context, azureautorest.FutureAPI) (isDone bool, err error) + Result(context.Context, azureautorest.FutureAPI, string) (result interface{}, err error) } // azureClient contains the Azure go-sdk Client. @@ -42,11 +51,13 @@ var _ client = (*azureClient)(nil) // newClient creates a new inbound NAT rules client from subscription ID. func newClient(auth azure.Authorizer) *azureClient { - c := newInboundNatRulesClient(auth.SubscriptionID(), auth.BaseURI(), auth.Authorizer()) - return &azureClient{c} + inboundNatRulesClient := newInboundNatRulesClient(auth.SubscriptionID(), auth.BaseURI(), auth.Authorizer()) + return &azureClient{ + inboundnatrules: inboundNatRulesClient, + } } -// newLoadbalancersClient creates a new inbound NAT rules client from subscription ID. +// newInboundNatClient creates a new inbound NAT rules client from subscription ID. func newInboundNatRulesClient(subscriptionID string, baseURI string, authorizer autorest.Authorizer) network.InboundNatRulesClient { inboundNatRulesClient := network.NewInboundNatRulesClientWithBaseURI(baseURI, subscriptionID) azure.SetAutoRestClientDefaults(&inboundNatRulesClient.Client, authorizer) @@ -54,43 +65,134 @@ func newInboundNatRulesClient(subscriptionID string, baseURI string, authorizer } // Get gets the specified inbound NAT rules. -func (ac *azureClient) Get(ctx context.Context, resourceGroupName, lbName, inboundNatRuleName string) (network.InboundNatRule, error) { - ctx, _, done := tele.StartSpanWithLogger(ctx, "inboundnatrules.AzureClient.Get") +func (ac *azureClient) Get(ctx context.Context, spec azure.ResourceSpecGetter) (result interface{}, err error) { + ctx, _, done := tele.StartSpanWithLogger(ctx, "inboundnatrules.azureClient.Get") + defer done() + + return ac.inboundnatrules.Get(ctx, spec.ResourceGroupName(), spec.OwnerResourceName(), spec.ResourceName(), "") +} + +// List returns all inbound NAT rules on a load balancer. +func (ac *azureClient) List(ctx context.Context, resourceGroupName, lbName string) (result []network.InboundNatRule, err error) { + ctx, _, done := tele.StartSpanWithLogger(ctx, "inboundnatrules.azureClient.List") defer done() - return ac.inboundnatrules.Get(ctx, resourceGroupName, lbName, inboundNatRuleName, "") + iter, err := ac.inboundnatrules.ListComplete(ctx, resourceGroupName, lbName) + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("could not list inbound NAT rules for load balancer %s", lbName)) + } + + var natRules []network.InboundNatRule + for iter.NotDone() { + natRules = append(natRules, iter.Value()) + if err := iter.NextWithContext(ctx); err != nil { + return natRules, errors.Wrap(err, "could not iterate inbound NAT rules") + } + } + + return natRules, nil } -// CreateOrUpdate creates or updates a inbound NAT rules. -func (ac *azureClient) CreateOrUpdate(ctx context.Context, resourceGroupName string, lbName string, inboundNatRuleName string, inboundNatRuleParameters network.InboundNatRule) error { - ctx, _, done := tele.StartSpanWithLogger(ctx, "inboundnatrules.AzureClient.CreateOrUpdate") +// CreateOrUpdateAsync creates or updates an inbound NAT rule asynchronously. +// It sends a PUT request to Azure and if accepted without error, the func will return a Future which can be used to track the ongoing +// progress of the operation. +func (ac *azureClient) CreateOrUpdateAsync(ctx context.Context, spec azure.ResourceSpecGetter, parameters interface{}) (result interface{}, future azureautorest.FutureAPI, err error) { + ctx, _, done := tele.StartSpanWithLogger(ctx, "inboundnatrules.azureClient.CreateOrUpdateAsync") defer done() - future, err := ac.inboundnatrules.CreateOrUpdate(ctx, resourceGroupName, lbName, inboundNatRuleName, inboundNatRuleParameters) + natRule, ok := parameters.(network.InboundNatRule) + if !ok { + return nil, nil, errors.Errorf("%T is not a network.InboundNatRule", parameters) + } + + createFuture, err := ac.inboundnatrules.CreateOrUpdate(ctx, spec.ResourceGroupName(), spec.OwnerResourceName(), spec.ResourceName(), natRule) if err != nil { - return err + return nil, nil, err } - err = future.WaitForCompletionRef(ctx, ac.inboundnatrules.Client) + + ctx, cancel := context.WithTimeout(ctx, reconciler.DefaultAzureCallTimeout) + defer cancel() + + err = createFuture.WaitForCompletionRef(ctx, ac.inboundnatrules.Client) if err != nil { - return err + // if an error occurs, return the future. + // this means the long-running operation didn't finish in the specified timeout. + return nil, &createFuture, err } - _, err = future.Result(ac.inboundnatrules) - return err + + result, err = createFuture.Result(ac.inboundnatrules) + // if the operation completed, return a nil future + return result, nil, err } -// Delete deletes the specified inbound NAT rules. -func (ac *azureClient) Delete(ctx context.Context, resourceGroupName, lbName, inboundNatRuleName string) error { - ctx, _, done := tele.StartSpanWithLogger(ctx, "inboundnatrules.AzureClient.Delete") +// DeleteAsync deletes an inbound NAT rule asynchronously. DeleteAsync sends a DELETE +// request to Azure and if accepted without error, the func will return a Future which can be used to track the ongoing +// progress of the operation. +func (ac *azureClient) DeleteAsync(ctx context.Context, spec azure.ResourceSpecGetter) (future azureautorest.FutureAPI, err error) { + ctx, _, done := tele.StartSpanWithLogger(ctx, "inboundnatrules.azureClient.DeleteAsync") defer done() - future, err := ac.inboundnatrules.Delete(ctx, resourceGroupName, lbName, inboundNatRuleName) + deleteFuture, err := ac.inboundnatrules.Delete(ctx, spec.ResourceGroupName(), spec.OwnerResourceName(), spec.ResourceName()) if err != nil { - return err + return nil, err } - err = future.WaitForCompletionRef(ctx, ac.inboundnatrules.Client) + + ctx, cancel := context.WithTimeout(ctx, reconciler.DefaultAzureCallTimeout) + defer cancel() + + err = deleteFuture.WaitForCompletionRef(ctx, ac.inboundnatrules.Client) + if err != nil { + // if an error occurs, return the future. + // this means the long-running operation didn't finish in the specified timeout. + return &deleteFuture, err + } + _, err = deleteFuture.Result(ac.inboundnatrules) + // if the operation completed, return a nil future. + return nil, err +} + +// IsDone returns true if the long-running operation has completed. +func (ac *azureClient) IsDone(ctx context.Context, future azureautorest.FutureAPI) (isDone bool, err error) { + ctx, _, done := tele.StartSpanWithLogger(ctx, "inboundnatrules.azureClient.IsDone") + defer done() + + isDone, err = future.DoneWithContext(ctx, ac.inboundnatrules) if err != nil { - return err + return false, errors.Wrap(err, "failed checking if the operation was complete") + } + + return isDone, nil +} + +// Result fetches the result of a long-running operation future. +func (ac *azureClient) Result(ctx context.Context, future azureautorest.FutureAPI, futureType string) (result interface{}, err error) { + _, _, done := tele.StartSpanWithLogger(ctx, "inboundnatrules.azureClient.Result") + defer done() + + if future == nil { + return nil, errors.Errorf("cannot get result from nil future") + } + + switch futureType { + case infrav1.PutFuture: + // Marshal and Unmarshal the future to put it into the correct future type so we can access the Result function. + // Unfortunately the FutureAPI can't be casted directly to InboundNatRulesCreateOrUpdateFuture because it is a azureautorest.Future, which doesn't implement the Result function. See PR #1686 for discussion on alternatives. + // It was converted back to a generic azureautorest.Future from the CAPZ infrav1.Future type stored in Status: https://github.com/kubernetes-sigs/cluster-api-provider-azure/blob/main/azure/converters/futures.go#L49. + var createFuture *network.InboundNatRulesCreateOrUpdateFuture + jsonData, err := future.MarshalJSON() + if err != nil { + return nil, errors.Wrap(err, "failed to marshal future") + } + if err := json.Unmarshal(jsonData, &createFuture); err != nil { + return nil, errors.Wrap(err, "failed to unmarshal future data") + } + return (*createFuture).Result(ac.inboundnatrules) + + case infrav1.DeleteFuture: + // Delete does not return a result inbound NAT rule + return nil, nil + + default: + return nil, errors.Errorf("unknown future type %q", futureType) } - _, err = future.Result(ac.inboundnatrules) - return err } diff --git a/azure/services/inboundnatrules/inboundnatrules.go b/azure/services/inboundnatrules/inboundnatrules.go index 553af9724ea..2d838ea02b6 100644 --- a/azure/services/inboundnatrules/inboundnatrules.go +++ b/azure/services/inboundnatrules/inboundnatrules.go @@ -19,136 +19,101 @@ package inboundnatrules import ( "context" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2021-02-01/network" - "github.com/Azure/go-autorest/autorest/to" "github.com/pkg/errors" + + infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" "sigs.k8s.io/cluster-api-provider-azure/azure" - "sigs.k8s.io/cluster-api-provider-azure/azure/services/loadbalancers" + "sigs.k8s.io/cluster-api-provider-azure/azure/services/async" + "sigs.k8s.io/cluster-api-provider-azure/util/reconciler" "sigs.k8s.io/cluster-api-provider-azure/util/tele" ) +const serviceName = "inboundnatrules" + // InboundNatScope defines the scope interface for an inbound NAT service. type InboundNatScope interface { azure.ClusterDescriber - InboundNatSpecs() []azure.InboundNatSpec + azure.AsyncStatusUpdater + APIServerLBName() string + InboundNatSpecs(map[int32]struct{}) []azure.ResourceSpecGetter } // Service provides operations on Azure resources. type Service struct { Scope InboundNatScope client - loadBalancersClient loadbalancers.Client + async.Reconciler } // New creates a new service. func New(scope InboundNatScope) *Service { + client := newClient(scope) return &Service{ - Scope: scope, - client: newClient(scope), - loadBalancersClient: loadbalancers.NewClient(scope), + Scope: scope, + client: client, + Reconciler: async.New(scope, client, client), } } // Reconcile gets/creates/updates an inbound NAT rule. func (s *Service) Reconcile(ctx context.Context) error { - ctx, log, done := tele.StartSpanWithLogger(ctx, "inboundnatrules.Service.Reconcile") + ctx, _, done := tele.StartSpanWithLogger(ctx, "inboundnatrules.Service.Reconcile") defer done() - for _, inboundNatSpec := range s.Scope.InboundNatSpecs() { - log.V(2).Info("creating inbound NAT rule", "NAT rule", inboundNatSpec.Name) - - lb, err := s.loadBalancersClient.Get(ctx, s.Scope.ResourceGroup(), inboundNatSpec.LoadBalancerName) - if err != nil { - return errors.Wrapf(err, "failed to get Load Balancer %s", inboundNatSpec.LoadBalancerName) - } + ctx, cancel := context.WithTimeout(ctx, reconciler.DefaultAzureServiceReconcileTimeout) + defer cancel() - if lb.LoadBalancerPropertiesFormat == nil || lb.FrontendIPConfigurations == nil || lb.InboundNatRules == nil { - return errors.Errorf("Could not get existing inbound NAT rules from load balancer %s properties", to.String(lb.Name)) - } - - ports := make(map[int32]struct{}) - if s.natRuleExists(ports)(ctx, *lb.InboundNatRules, inboundNatSpec.Name) { - // Inbound NAT Rule already exists, nothing to do here. - continue - } - - sshFrontendPort, err := s.getAvailablePort(ctx, ports) - if err != nil { - return errors.Wrapf(err, "failed to find available SSH Frontend port for NAT Rule %s in load balancer %s", inboundNatSpec.Name, to.String(lb.Name)) - } + existingRules, err := s.client.List(ctx, s.Scope.ResourceGroup(), s.Scope.APIServerLBName()) + if err != nil { + result := errors.Wrapf(err, "failed to get existing NAT rules") + s.Scope.UpdatePutStatus(infrav1.InboundNATRulesReadyCondition, serviceName, result) + return result + } - rule := network.InboundNatRule{ - Name: to.StringPtr(inboundNatSpec.Name), - InboundNatRulePropertiesFormat: &network.InboundNatRulePropertiesFormat{ - BackendPort: to.Int32Ptr(22), - EnableFloatingIP: to.BoolPtr(false), - IdleTimeoutInMinutes: to.Int32Ptr(4), - FrontendIPConfiguration: &network.SubResource{ - ID: (*lb.FrontendIPConfigurations)[0].ID, - }, - Protocol: network.TransportProtocolTCP, - FrontendPort: &sshFrontendPort, - }, - } - log.V(3).Info("Creating rule %s using port %d", "NAT rule", inboundNatSpec.Name, "port", sshFrontendPort) + portsInUse := make(map[int32]struct{}) + for _, rule := range existingRules { + portsInUse[*rule.InboundNatRulePropertiesFormat.FrontendPort] = struct{}{} // Mark frontend port as in use + } - err = s.client.CreateOrUpdate(ctx, s.Scope.ResourceGroup(), to.String(lb.Name), inboundNatSpec.Name, rule) - if err != nil { - return errors.Wrapf(err, "failed to create inbound NAT rule %s", inboundNatSpec.Name) + // We go through the list of InboundNatSpecs to reconcile each one, independently of the result of the previous one. + // If multiple errors occur, we return the most pressing one. + // Order of precedence (highest -> lowest) is: error that is not an operationNotDoneError (i.e. error creating) -> operationNotDoneError (i.e. creating in progress) -> no error (i.e. created) + var result error + for _, natRule := range s.Scope.InboundNatSpecs(portsInUse) { + // If we are creating multiple inbound NAT rules, we could have a collision in finding an available frontend port since the newly created rule takes an available port, and we do not update portsInUse in the specs. + // It doesn't matter in this case since we only create one rule per machine, but for multiple rules, we could end up restarting the Reconcile function each time to get the updated available ports. + // TODO: We can update the available ports and recompute the specs each time, or alternatively, we could deterministically calculate the ports we plan on using to avoid collisions, i.e. rule #1 uses the first available port, rule #2 uses the second available port, etc. + if _, err := s.CreateResource(ctx, natRule, serviceName); err != nil { + if !azure.IsOperationNotDoneError(err) || result == nil { + result = err + } } - - log.V(2).Info("successfully created inbound NAT rule", "NAT rule", inboundNatSpec.Name) } - return nil + + s.Scope.UpdatePutStatus(infrav1.InboundNATRulesReadyCondition, serviceName, result) + return result } // Delete deletes the inbound NAT rule with the provided name. func (s *Service) Delete(ctx context.Context) error { - ctx, log, done := tele.StartSpanWithLogger(ctx, "inboundnatrules.Service.Delete") + ctx, _, done := tele.StartSpanWithLogger(ctx, "inboundnatrules.Service.Delete") defer done() - for _, inboundNatSpec := range s.Scope.InboundNatSpecs() { - log.V(2).Info("deleting inbound NAT rule", "NAT rule", inboundNatSpec.Name) - err := s.client.Delete(ctx, s.Scope.ResourceGroup(), inboundNatSpec.LoadBalancerName, inboundNatSpec.Name) - if err != nil && !azure.ResourceNotFound(err) { - return errors.Wrapf(err, "failed to delete inbound NAT rule %s", inboundNatSpec.Name) - } - - log.V(2).Info("successfully deleted inbound NAT rule", "NAT rule", inboundNatSpec.Name) - } - return nil -} - -func (s *Service) natRuleExists(ports map[int32]struct{}) func(context.Context, []network.InboundNatRule, string) bool { - return func(ctx context.Context, rules []network.InboundNatRule, name string) bool { - _, log, done := tele.StartSpanWithLogger(ctx, "inboundnatrules.Service.natRuleExists") - defer done() + ctx, cancel := context.WithTimeout(ctx, reconciler.DefaultAzureServiceReconcileTimeout) + defer cancel() - for _, v := range rules { - if to.String(v.Name) == name { - log.V(2).Info("NAT rule already exists", "NAT rule", name) - return true - } - ports[*v.InboundNatRulePropertiesFormat.FrontendPort] = struct{}{} - } - return false - } -} - -func (s *Service) getAvailablePort(ctx context.Context, ports map[int32]struct{}) (int32, error) { - _, log, done := tele.StartSpanWithLogger(ctx, "inboundnatrules.Service.getAvailablePort") - defer done() + var result error - var i int32 = 22 - if _, ok := ports[22]; ok { - for i = 2201; i < 2220; i++ { - if _, ok := ports[i]; !ok { - log.V(2).Info("Found available port", "port", i) - return i, nil + // We go through the list of InboundNatSpecs to delete each one, independently of the result of the previous one. + // If multiple errors occur, we return the most pressing one. + // Order of precedence (highest -> lowest) is: error that is not an operationNotDoneError (i.e. error deleting) -> operationNotDoneError (i.e. deleting in progress) -> no error (i.e. deleted) + for _, natRule := range s.Scope.InboundNatSpecs(make(map[int32]struct{})) { + if err := s.DeleteResource(ctx, natRule, serviceName); err != nil { + if !azure.IsOperationNotDoneError(err) || result == nil { + result = err } } - return i, errors.Errorf("No available SSH Frontend ports") } - log.V(2).Info("Found available port", "port", i) - return i, nil + s.Scope.UpdateDeleteStatus(infrav1.InboundNATRulesReadyCondition, serviceName, result) + return result } diff --git a/azure/services/inboundnatrules/inboundnatrules_test.go b/azure/services/inboundnatrules/inboundnatrules_test.go index 47a3d3200e5..fd534e687e7 100644 --- a/azure/services/inboundnatrules/inboundnatrules_test.go +++ b/azure/services/inboundnatrules/inboundnatrules_test.go @@ -27,194 +27,130 @@ import ( "github.com/golang/mock/gomock" . "github.com/onsi/gomega" "k8s.io/utils/pointer" + + infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" "sigs.k8s.io/cluster-api-provider-azure/azure" + "sigs.k8s.io/cluster-api-provider-azure/azure/services/async/mock_async" "sigs.k8s.io/cluster-api-provider-azure/azure/services/inboundnatrules/mock_inboundnatrules" - "sigs.k8s.io/cluster-api-provider-azure/azure/services/loadbalancers/mock_loadbalancers" gomockinternal "sigs.k8s.io/cluster-api-provider-azure/internal/test/matchers/gomock" ) +var ( + fakeLBName = "my-lb-1" + fakeGroupName = "my-rg" + + noPortsInUse = getFakeExistingPortsInUse([]int{}) + noExistingRules = []network.InboundNatRule{} + fakeExistingRules = []network.InboundNatRule{ + { + Name: pointer.StringPtr("other-machine-nat-rule"), + ID: pointer.StringPtr("some-natrules-id"), + InboundNatRulePropertiesFormat: &network.InboundNatRulePropertiesFormat{ + FrontendPort: to.Int32Ptr(22), + }, + }, + { + Name: pointer.StringPtr("other-machine-nat-rule-2"), + ID: pointer.StringPtr("some-natrules-id-2"), + InboundNatRulePropertiesFormat: &network.InboundNatRulePropertiesFormat{ + FrontendPort: to.Int32Ptr(2201), + }, + }, + } + somePortsInUse = getFakeExistingPortsInUse([]int{22, 2201}) + + fakeNatSpecWithNoExisting = InboundNatSpec{ + Name: "my-machine-1", + LoadBalancerName: "my-lb-1", + ResourceGroup: fakeGroupName, + FrontendIPConfigurationID: to.StringPtr("frontend-ip-config-id-2"), + PortsInUse: noPortsInUse, + } + fakeNatSpec = InboundNatSpec{ + Name: "my-machine-2", + LoadBalancerName: "my-lb-2", + ResourceGroup: fakeGroupName, + FrontendIPConfigurationID: to.StringPtr("frontend-ip-config-id-2"), + PortsInUse: somePortsInUse, + } + internalError = autorest.NewErrorWithResponse("", "", &http.Response{StatusCode: 500}, "Internal Server Error") +) + +func getFakeExistingPortsInUse(usedPorts []int) map[int32]struct{} { + portsInUse := make(map[int32]struct{}) + for _, port := range usedPorts { + portsInUse[int32(port)] = struct{}{} + } + + return portsInUse +} + func TestReconcileInboundNATRule(t *testing.T) { testcases := []struct { name string expectedError string expect func(s *mock_inboundnatrules.MockInboundNatScopeMockRecorder, m *mock_inboundnatrules.MockclientMockRecorder, - mLoadBalancer *mock_loadbalancers.MockClientMockRecorder) + r *mock_async.MockReconcilerMockRecorder) }{ { - name: "NAT rule successfully created", + name: "NAT rule successfully created with with no existing rules", expectedError: "", expect: func(s *mock_inboundnatrules.MockInboundNatScopeMockRecorder, m *mock_inboundnatrules.MockclientMockRecorder, - mLoadBalancer *mock_loadbalancers.MockClientMockRecorder) { - s.InboundNatSpecs().Return([]azure.InboundNatSpec{ - { - Name: "my-machine", - LoadBalancerName: "my-lb", - }, - }) - s.ResourceGroup().AnyTimes().Return("my-rg") - s.Location().AnyTimes().Return("fake-location") + r *mock_async.MockReconcilerMockRecorder) { + s.ResourceGroup().AnyTimes().Return(fakeGroupName) + s.APIServerLBName().AnyTimes().Return(fakeLBName) + m.List(gomockinternal.AContext(), fakeGroupName, fakeLBName).Return(noExistingRules, nil) + s.InboundNatSpecs(noPortsInUse).Return([]azure.ResourceSpecGetter{&fakeNatSpecWithNoExisting}) gomock.InOrder( - mLoadBalancer.Get(gomockinternal.AContext(), "my-rg", "my-lb").Return(network.LoadBalancer{ - Name: to.StringPtr("my-lb"), - ID: pointer.StringPtr("my-lb-id"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - FrontendIPConfigurations: &[]network.FrontendIPConfiguration{ - { - ID: to.StringPtr("frontend-ip-config-id"), - }, - }, - InboundNatRules: &[]network.InboundNatRule{}, - }}, nil), - m.CreateOrUpdate(gomockinternal.AContext(), "my-rg", "my-lb", "my-machine", network.InboundNatRule{ - Name: pointer.StringPtr("my-machine"), - InboundNatRulePropertiesFormat: &network.InboundNatRulePropertiesFormat{ - FrontendPort: to.Int32Ptr(22), - BackendPort: to.Int32Ptr(22), - EnableFloatingIP: to.BoolPtr(false), - IdleTimeoutInMinutes: to.Int32Ptr(4), - FrontendIPConfiguration: &network.SubResource{ - ID: to.StringPtr("frontend-ip-config-id"), - }, - Protocol: network.TransportProtocolTCP, - }, - })) + r.CreateResource(gomockinternal.AContext(), &fakeNatSpecWithNoExisting, serviceName).Return(nil, nil), + s.UpdatePutStatus(infrav1.InboundNATRulesReadyCondition, serviceName, nil), + ) }, }, { - name: "fail to get LB", - expectedError: "failed to get Load Balancer my-public-lb: #: Internal Server Error: StatusCode=500", + name: "NAT rule successfully created with with existing rules", + expectedError: "", expect: func(s *mock_inboundnatrules.MockInboundNatScopeMockRecorder, m *mock_inboundnatrules.MockclientMockRecorder, - mLoadBalancer *mock_loadbalancers.MockClientMockRecorder) { - s.InboundNatSpecs().Return([]azure.InboundNatSpec{ - { - Name: "my-machine", - LoadBalancerName: "my-public-lb", - }, - }) - s.ResourceGroup().AnyTimes().Return("my-rg") - s.Location().AnyTimes().Return("fake-location") + r *mock_async.MockReconcilerMockRecorder) { + s.ResourceGroup().AnyTimes().Return(fakeGroupName) + s.APIServerLBName().AnyTimes().Return("my-lb") + m.List(gomockinternal.AContext(), fakeGroupName, "my-lb").Return(fakeExistingRules, nil) + s.InboundNatSpecs(somePortsInUse).Return([]azure.ResourceSpecGetter{&fakeNatSpec}) gomock.InOrder( - mLoadBalancer.Get(gomockinternal.AContext(), "my-rg", "my-public-lb"). - Return(network.LoadBalancer{}, autorest.NewErrorWithResponse("", "", &http.Response{StatusCode: 500}, "Internal Server Error"))) + r.CreateResource(gomockinternal.AContext(), &fakeNatSpec, serviceName).Return(nil, nil), + s.UpdatePutStatus(infrav1.InboundNATRulesReadyCondition, serviceName, nil), + ) }, }, { - name: "fail to create NAT rule", - expectedError: "failed to create inbound NAT rule my-machine: #: Internal Server Error: StatusCode=500", + name: "fail to get existing rules", + expectedError: "failed to get existing NAT rules: #: Internal Server Error: StatusCode=500", expect: func(s *mock_inboundnatrules.MockInboundNatScopeMockRecorder, m *mock_inboundnatrules.MockclientMockRecorder, - mLoadBalancer *mock_loadbalancers.MockClientMockRecorder) { - s.InboundNatSpecs().Return([]azure.InboundNatSpec{ - { - Name: "my-machine", - LoadBalancerName: "my-public-lb", - }, - }) - s.ResourceGroup().AnyTimes().Return("my-rg") - s.Location().AnyTimes().Return("fake-location") - gomock.InOrder( - mLoadBalancer.Get(gomockinternal.AContext(), "my-rg", "my-public-lb").Return(network.LoadBalancer{ - Name: to.StringPtr("my-public-lb"), - ID: pointer.StringPtr("my-public-lb-id"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - FrontendIPConfigurations: &[]network.FrontendIPConfiguration{ - { - ID: to.StringPtr("frontend-ip-config-id"), - }, - }, - InboundNatRules: &[]network.InboundNatRule{ - { - Name: pointer.StringPtr("other-machine-nat-rule"), - ID: pointer.StringPtr("some-natrules-id"), - InboundNatRulePropertiesFormat: &network.InboundNatRulePropertiesFormat{ - FrontendPort: to.Int32Ptr(22), - }, - }, - { - Name: pointer.StringPtr("other-machine-nat-rule-2"), - ID: pointer.StringPtr("some-natrules-id-2"), - InboundNatRulePropertiesFormat: &network.InboundNatRulePropertiesFormat{ - FrontendPort: to.Int32Ptr(2201), - }, - }, - }, - }}, nil), - m.CreateOrUpdate(gomockinternal.AContext(), "my-rg", "my-public-lb", "my-machine", network.InboundNatRule{ - Name: pointer.StringPtr("my-machine"), - InboundNatRulePropertiesFormat: &network.InboundNatRulePropertiesFormat{ - FrontendPort: to.Int32Ptr(2202), - BackendPort: to.Int32Ptr(22), - EnableFloatingIP: to.BoolPtr(false), - IdleTimeoutInMinutes: to.Int32Ptr(4), - FrontendIPConfiguration: &network.SubResource{ - ID: to.StringPtr("frontend-ip-config-id"), - }, - Protocol: network.TransportProtocolTCP, - }, - }). - Return(autorest.NewErrorWithResponse("", "", &http.Response{StatusCode: 500}, "Internal Server Error"))) + r *mock_async.MockReconcilerMockRecorder) { + s.ResourceGroup().AnyTimes().Return(fakeGroupName) + s.APIServerLBName().AnyTimes().Return("my-lb") + m.List(gomockinternal.AContext(), fakeGroupName, "my-lb").Return(nil, internalError) + s.UpdatePutStatus(infrav1.InboundNATRulesReadyCondition, serviceName, gomockinternal.ErrStrEq("failed to get existing NAT rules: #: Internal Server Error: StatusCode=500")) }, }, { - name: "NAT rule already exists", - expectedError: "", + name: "fail to create NAT rule", + expectedError: "#: Internal Server Error: StatusCode=500", expect: func(s *mock_inboundnatrules.MockInboundNatScopeMockRecorder, m *mock_inboundnatrules.MockclientMockRecorder, - mLoadBalancer *mock_loadbalancers.MockClientMockRecorder) { - s.InboundNatSpecs().Return([]azure.InboundNatSpec{ - { - Name: "my-machine-nat-rule", - LoadBalancerName: "my-public-lb", - }, - { - Name: "my-other-nat-rule", - LoadBalancerName: "my-other-public-lb", - }, - }) - s.ResourceGroup().AnyTimes().Return("my-rg") - s.Location().AnyTimes().Return("fake-location") + r *mock_async.MockReconcilerMockRecorder) { + s.ResourceGroup().AnyTimes().Return(fakeGroupName) + s.APIServerLBName().AnyTimes().Return("my-lb") + m.List(gomockinternal.AContext(), fakeGroupName, "my-lb").Return(fakeExistingRules, nil) + s.InboundNatSpecs(somePortsInUse).Return([]azure.ResourceSpecGetter{&fakeNatSpec}) gomock.InOrder( - mLoadBalancer.Get(gomockinternal.AContext(), "my-rg", "my-public-lb").Return(network.LoadBalancer{ - Name: to.StringPtr("my-public-lb"), - ID: pointer.StringPtr("my-public-lb-id"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - FrontendIPConfigurations: &[]network.FrontendIPConfiguration{ - { - ID: to.StringPtr("frontend-ip-config-id"), - }, - }, - InboundNatRules: &[]network.InboundNatRule{ - { - Name: pointer.StringPtr("my-machine-nat-rule"), - ID: pointer.StringPtr("some-natrules-id"), - InboundNatRulePropertiesFormat: &network.InboundNatRulePropertiesFormat{ - FrontendPort: to.Int32Ptr(22), - }, - }, - { - Name: pointer.StringPtr("other-machine-nat-rule-2"), - ID: pointer.StringPtr("some-natrules-id-2"), - InboundNatRulePropertiesFormat: &network.InboundNatRulePropertiesFormat{ - FrontendPort: to.Int32Ptr(2201), - }, - }, - }, - }}, nil), - mLoadBalancer.Get(gomockinternal.AContext(), "my-rg", "my-other-public-lb").Return(network.LoadBalancer{ - Name: to.StringPtr("my-other-public-lb"), - ID: pointer.StringPtr("my-public-lb-id"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - FrontendIPConfigurations: &[]network.FrontendIPConfiguration{ - { - ID: to.StringPtr("frontend-ip-config-id"), - }, - }, - InboundNatRules: &[]network.InboundNatRule{}, - }}, nil), - m.CreateOrUpdate(gomockinternal.AContext(), "my-rg", "my-other-public-lb", "my-other-nat-rule", gomock.AssignableToTypeOf(network.InboundNatRule{}))) + r.CreateResource(gomockinternal.AContext(), &fakeNatSpec, serviceName).Return(nil, internalError), + s.UpdatePutStatus(infrav1.InboundNATRulesReadyCondition, serviceName, internalError), + ) }, }, } @@ -228,14 +164,14 @@ func TestReconcileInboundNATRule(t *testing.T) { defer mockCtrl.Finish() scopeMock := mock_inboundnatrules.NewMockInboundNatScope(mockCtrl) clientMock := mock_inboundnatrules.NewMockclient(mockCtrl) - loadBalancerMock := mock_loadbalancers.NewMockClient(mockCtrl) + asyncMock := mock_async.NewMockReconciler(mockCtrl) - tc.expect(scopeMock.EXPECT(), clientMock.EXPECT(), loadBalancerMock.EXPECT()) + tc.expect(scopeMock.EXPECT(), clientMock.EXPECT(), asyncMock.EXPECT()) s := &Service{ - Scope: scopeMock, - client: clientMock, - loadBalancersClient: loadBalancerMock, + Scope: scopeMock, + client: clientMock, + Reconciler: asyncMock, } err := s.Reconcile(context.TODO()) @@ -254,53 +190,34 @@ func TestDeleteNetworkInterface(t *testing.T) { name string expectedError string expect func(s *mock_inboundnatrules.MockInboundNatScopeMockRecorder, - m *mock_inboundnatrules.MockclientMockRecorder, mLoadBalancer *mock_loadbalancers.MockClientMockRecorder) + m *mock_inboundnatrules.MockclientMockRecorder, r *mock_async.MockReconcilerMockRecorder) }{ { name: "successfully delete an existing NAT rule", expectedError: "", expect: func(s *mock_inboundnatrules.MockInboundNatScopeMockRecorder, - m *mock_inboundnatrules.MockclientMockRecorder, mLoadBalancer *mock_loadbalancers.MockClientMockRecorder) { - s.InboundNatSpecs().Return([]azure.InboundNatSpec{ - { - Name: "azure-md-0", - LoadBalancerName: "my-public-lb", - }, - }) - s.ResourceGroup().AnyTimes().Return("my-rg") - m.Delete(gomockinternal.AContext(), "my-rg", "my-public-lb", "azure-md-0") - }, - }, - { - name: "NAT rule already deleted", - expectedError: "", - expect: func(s *mock_inboundnatrules.MockInboundNatScopeMockRecorder, - m *mock_inboundnatrules.MockclientMockRecorder, mLoadBalancer *mock_loadbalancers.MockClientMockRecorder) { - s.InboundNatSpecs().Return([]azure.InboundNatSpec{ - { - Name: "azure-md-1", - LoadBalancerName: "my-public-lb", - }, - }) - s.ResourceGroup().AnyTimes().Return("my-rg") - m.Delete(gomockinternal.AContext(), "my-rg", "my-public-lb", "azure-md-1"). - Return(autorest.NewErrorWithResponse("", "", &http.Response{StatusCode: 404}, "Not found")) + m *mock_inboundnatrules.MockclientMockRecorder, r *mock_async.MockReconcilerMockRecorder) { + s.InboundNatSpecs(noPortsInUse).Return([]azure.ResourceSpecGetter{&fakeNatSpecWithNoExisting}) + s.ResourceGroup().AnyTimes().Return(fakeGroupName) + s.APIServerLBName().AnyTimes().Return(fakeLBName) + gomock.InOrder( + r.DeleteResource(gomockinternal.AContext(), &fakeNatSpecWithNoExisting, serviceName).Return(nil), + s.UpdateDeleteStatus(infrav1.InboundNATRulesReadyCondition, serviceName, nil), + ) }, }, { name: "NAT rule deletion fails", - expectedError: "failed to delete inbound NAT rule azure-md-2: #: Internal Server Error: StatusCode=500", + expectedError: "#: Internal Server Error: StatusCode=500", expect: func(s *mock_inboundnatrules.MockInboundNatScopeMockRecorder, - m *mock_inboundnatrules.MockclientMockRecorder, mLoadBalancer *mock_loadbalancers.MockClientMockRecorder) { - s.InboundNatSpecs().Return([]azure.InboundNatSpec{ - { - Name: "azure-md-2", - LoadBalancerName: "my-public-lb", - }, - }) - s.ResourceGroup().AnyTimes().Return("my-rg") - m.Delete(gomockinternal.AContext(), "my-rg", "my-public-lb", "azure-md-2"). - Return(autorest.NewErrorWithResponse("", "", &http.Response{StatusCode: 500}, "Internal Server Error")) + m *mock_inboundnatrules.MockclientMockRecorder, r *mock_async.MockReconcilerMockRecorder) { + s.InboundNatSpecs(noPortsInUse).Return([]azure.ResourceSpecGetter{&fakeNatSpecWithNoExisting}) + s.ResourceGroup().AnyTimes().Return(fakeGroupName) + s.APIServerLBName().AnyTimes().Return(fakeLBName) + gomock.InOrder( + r.DeleteResource(gomockinternal.AContext(), &fakeNatSpecWithNoExisting, serviceName).Return(internalError), + s.UpdateDeleteStatus(infrav1.InboundNATRulesReadyCondition, serviceName, internalError), + ) }, }, } @@ -314,14 +231,14 @@ func TestDeleteNetworkInterface(t *testing.T) { defer mockCtrl.Finish() scopeMock := mock_inboundnatrules.NewMockInboundNatScope(mockCtrl) clientMock := mock_inboundnatrules.NewMockclient(mockCtrl) - loadBalancerMock := mock_loadbalancers.NewMockClient(mockCtrl) + asyncMock := mock_async.NewMockReconciler(mockCtrl) - tc.expect(scopeMock.EXPECT(), clientMock.EXPECT(), loadBalancerMock.EXPECT()) + tc.expect(scopeMock.EXPECT(), clientMock.EXPECT(), asyncMock.EXPECT()) s := &Service{ - Scope: scopeMock, - client: clientMock, - loadBalancersClient: loadBalancerMock, + Scope: scopeMock, + client: clientMock, + Reconciler: asyncMock, } err := s.Delete(context.TODO()) @@ -334,197 +251,3 @@ func TestDeleteNetworkInterface(t *testing.T) { }) } } - -func TestNatRuleExists(t *testing.T) { - testcases := []struct { - name string - ruleName string - existingRules []network.InboundNatRule - expectedResult bool - expectedPorts map[int32]struct{} - expect func(s *mock_inboundnatrules.MockInboundNatScopeMockRecorder, - m *mock_inboundnatrules.MockclientMockRecorder, mLoadBalancer *mock_loadbalancers.MockClientMockRecorder) - }{ - { - name: "Rule exists", - ruleName: "my-rule", - existingRules: []network.InboundNatRule{ - { - InboundNatRulePropertiesFormat: &network.InboundNatRulePropertiesFormat{ - FrontendPort: to.Int32Ptr(2201), - }, - Name: to.StringPtr("some-rule"), - }, - { - InboundNatRulePropertiesFormat: &network.InboundNatRulePropertiesFormat{ - FrontendPort: to.Int32Ptr(22), - }, - Name: to.StringPtr("my-rule"), - }, - }, - expectedResult: true, - expect: func(s *mock_inboundnatrules.MockInboundNatScopeMockRecorder, - m *mock_inboundnatrules.MockclientMockRecorder, mLoadBalancer *mock_loadbalancers.MockClientMockRecorder) { - }, - }, - { - name: "Rule doesn't exist", - ruleName: "my-rule", - existingRules: []network.InboundNatRule{ - { - InboundNatRulePropertiesFormat: &network.InboundNatRulePropertiesFormat{ - FrontendPort: to.Int32Ptr(22), - }, - Name: to.StringPtr("other-rule"), - }, - { - InboundNatRulePropertiesFormat: &network.InboundNatRulePropertiesFormat{ - FrontendPort: to.Int32Ptr(2205), - }, - Name: to.StringPtr("other-rule-2"), - }, - }, - expectedResult: false, - expectedPorts: map[int32]struct{}{ - 22: {}, - 2205: {}, - }, - expect: func(s *mock_inboundnatrules.MockInboundNatScopeMockRecorder, - m *mock_inboundnatrules.MockclientMockRecorder, mLoadBalancer *mock_loadbalancers.MockClientMockRecorder) { - }, - }, - { - name: "No rules exist", - ruleName: "my-rule", - existingRules: []network.InboundNatRule{}, - expectedResult: false, - expectedPorts: map[int32]struct{}{}, - expect: func(s *mock_inboundnatrules.MockInboundNatScopeMockRecorder, - m *mock_inboundnatrules.MockclientMockRecorder, mLoadBalancer *mock_loadbalancers.MockClientMockRecorder) { - }, - }, - } - - for _, tc := range testcases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - g := NewWithT(t) - t.Parallel() - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - scopeMock := mock_inboundnatrules.NewMockInboundNatScope(mockCtrl) - clientMock := mock_inboundnatrules.NewMockclient(mockCtrl) - loadBalancerMock := mock_loadbalancers.NewMockClient(mockCtrl) - - tc.expect(scopeMock.EXPECT(), clientMock.EXPECT(), loadBalancerMock.EXPECT()) - - s := &Service{ - Scope: scopeMock, - client: clientMock, - loadBalancersClient: loadBalancerMock, - } - - ports := make(map[int32]struct{}) - exists := s.natRuleExists(ports)(context.TODO(), tc.existingRules, tc.ruleName) - g.Expect(exists).To(Equal(tc.expectedResult)) - if !exists { - g.Expect(ports).To(Equal(tc.expectedPorts)) - } - }) - } -} - -func TestGetAvailablePort(t *testing.T) { - testcases := []struct { - name string - portsInput map[int32]struct{} - expectedError string - expectedPortResult int32 - expect func(s *mock_inboundnatrules.MockInboundNatScopeMockRecorder, - m *mock_inboundnatrules.MockclientMockRecorder, mLoadBalancer *mock_loadbalancers.MockClientMockRecorder) - }{ - { - name: "Empty ports", - portsInput: map[int32]struct{}{}, - expectedError: "", - expectedPortResult: 22, - expect: func(s *mock_inboundnatrules.MockInboundNatScopeMockRecorder, - m *mock_inboundnatrules.MockclientMockRecorder, mLoadBalancer *mock_loadbalancers.MockClientMockRecorder) { - }, - }, - { - name: "22 taken", - portsInput: map[int32]struct{}{ - 22: {}, - }, - expectedError: "", - expectedPortResult: 2201, - expect: func(s *mock_inboundnatrules.MockInboundNatScopeMockRecorder, - m *mock_inboundnatrules.MockclientMockRecorder, mLoadBalancer *mock_loadbalancers.MockClientMockRecorder) { - }, - }, - { - name: "Existing ports", - portsInput: map[int32]struct{}{ - 22: {}, - 2201: {}, - 2202: {}, - 2204: {}, - }, - expectedError: "", - expectedPortResult: 2203, - expect: func(s *mock_inboundnatrules.MockInboundNatScopeMockRecorder, - m *mock_inboundnatrules.MockclientMockRecorder, mLoadBalancer *mock_loadbalancers.MockClientMockRecorder) { - }, - }, - { - name: "No ports available", - portsInput: getFullPortsMap(), - expectedError: "No available SSH Frontend ports", - expectedPortResult: 0, - expect: func(s *mock_inboundnatrules.MockInboundNatScopeMockRecorder, - m *mock_inboundnatrules.MockclientMockRecorder, mLoadBalancer *mock_loadbalancers.MockClientMockRecorder) { - }, - }, - } - - for _, tc := range testcases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - g := NewWithT(t) - t.Parallel() - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - scopeMock := mock_inboundnatrules.NewMockInboundNatScope(mockCtrl) - clientMock := mock_inboundnatrules.NewMockclient(mockCtrl) - loadBalancerMock := mock_loadbalancers.NewMockClient(mockCtrl) - - tc.expect(scopeMock.EXPECT(), clientMock.EXPECT(), loadBalancerMock.EXPECT()) - - s := &Service{ - Scope: scopeMock, - client: clientMock, - loadBalancersClient: loadBalancerMock, - } - - res, err := s.getAvailablePort(context.TODO(), tc.portsInput) - if tc.expectedError != "" { - g.Expect(err).To(HaveOccurred()) - g.Expect(err).To(MatchError(tc.expectedError)) - } else { - g.Expect(err).NotTo(HaveOccurred()) - g.Expect(res).To(Equal(tc.expectedPortResult)) - } - }) - } -} - -func getFullPortsMap() map[int32]struct{} { - res := map[int32]struct{}{ - 22: {}, - } - for i := 2201; i < 2220; i++ { - res[int32(i)] = struct{}{} - } - return res -} diff --git a/azure/services/inboundnatrules/mock_inboundnatrules/client_mock.go b/azure/services/inboundnatrules/mock_inboundnatrules/client_mock.go index f406b497e02..7b7ea0ecaaf 100644 --- a/azure/services/inboundnatrules/mock_inboundnatrules/client_mock.go +++ b/azure/services/inboundnatrules/mock_inboundnatrules/client_mock.go @@ -25,7 +25,9 @@ import ( reflect "reflect" network "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2021-02-01/network" + azure "github.com/Azure/go-autorest/autorest/azure" gomock "github.com/golang/mock/gomock" + azure0 "sigs.k8s.io/cluster-api-provider-azure/azure" ) // Mockclient is a mock of client interface. @@ -51,45 +53,93 @@ func (m *Mockclient) EXPECT() *MockclientMockRecorder { return m.recorder } -// CreateOrUpdate mocks base method. -func (m *Mockclient) CreateOrUpdate(arg0 context.Context, arg1, arg2, arg3 string, arg4 network.InboundNatRule) error { +// CreateOrUpdateAsync mocks base method. +func (m *Mockclient) CreateOrUpdateAsync(arg0 context.Context, arg1 azure0.ResourceSpecGetter, arg2 interface{}) (interface{}, azure.FutureAPI, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateOrUpdate", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "CreateOrUpdateAsync", arg0, arg1, arg2) + ret0, _ := ret[0].(interface{}) + ret1, _ := ret[1].(azure.FutureAPI) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } -// CreateOrUpdate indicates an expected call of CreateOrUpdate. -func (mr *MockclientMockRecorder) CreateOrUpdate(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { +// CreateOrUpdateAsync indicates an expected call of CreateOrUpdateAsync. +func (mr *MockclientMockRecorder) CreateOrUpdateAsync(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOrUpdate", reflect.TypeOf((*Mockclient)(nil).CreateOrUpdate), arg0, arg1, arg2, arg3, arg4) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOrUpdateAsync", reflect.TypeOf((*Mockclient)(nil).CreateOrUpdateAsync), arg0, arg1, arg2) } -// Delete mocks base method. -func (m *Mockclient) Delete(arg0 context.Context, arg1, arg2, arg3 string) error { +// DeleteAsync mocks base method. +func (m *Mockclient) DeleteAsync(arg0 context.Context, arg1 azure0.ResourceSpecGetter) (azure.FutureAPI, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Delete", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "DeleteAsync", arg0, arg1) + ret0, _ := ret[0].(azure.FutureAPI) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// Delete indicates an expected call of Delete. -func (mr *MockclientMockRecorder) Delete(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +// DeleteAsync indicates an expected call of DeleteAsync. +func (mr *MockclientMockRecorder) DeleteAsync(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*Mockclient)(nil).Delete), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAsync", reflect.TypeOf((*Mockclient)(nil).DeleteAsync), arg0, arg1) } // Get mocks base method. -func (m *Mockclient) Get(arg0 context.Context, arg1, arg2, arg3 string) (network.InboundNatRule, error) { +func (m *Mockclient) Get(arg0 context.Context, arg1 azure0.ResourceSpecGetter) (interface{}, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(network.InboundNatRule) + ret := m.ctrl.Call(m, "Get", arg0, arg1) + ret0, _ := ret[0].(interface{}) ret1, _ := ret[1].(error) return ret0, ret1 } // Get indicates an expected call of Get. -func (mr *MockclientMockRecorder) Get(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockclientMockRecorder) Get(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*Mockclient)(nil).Get), arg0, arg1) +} + +// IsDone mocks base method. +func (m *Mockclient) IsDone(arg0 context.Context, arg1 azure.FutureAPI) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsDone", arg0, arg1) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsDone indicates an expected call of IsDone. +func (mr *MockclientMockRecorder) IsDone(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsDone", reflect.TypeOf((*Mockclient)(nil).IsDone), arg0, arg1) +} + +// List mocks base method. +func (m *Mockclient) List(arg0 context.Context, arg1, arg2 string) ([]network.InboundNatRule, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List", arg0, arg1, arg2) + ret0, _ := ret[0].([]network.InboundNatRule) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// List indicates an expected call of List. +func (mr *MockclientMockRecorder) List(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*Mockclient)(nil).List), arg0, arg1, arg2) +} + +// Result mocks base method. +func (m *Mockclient) Result(arg0 context.Context, arg1 azure.FutureAPI, arg2 string) (interface{}, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Result", arg0, arg1, arg2) + ret0, _ := ret[0].(interface{}) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Result indicates an expected call of Result. +func (mr *MockclientMockRecorder) Result(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*Mockclient)(nil).Get), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Result", reflect.TypeOf((*Mockclient)(nil).Result), arg0, arg1, arg2) } diff --git a/azure/services/inboundnatrules/mock_inboundnatrules/inboundnatrules_mock.go b/azure/services/inboundnatrules/mock_inboundnatrules/inboundnatrules_mock.go index 72b637170b1..9cb5e890bc5 100644 --- a/azure/services/inboundnatrules/mock_inboundnatrules/inboundnatrules_mock.go +++ b/azure/services/inboundnatrules/mock_inboundnatrules/inboundnatrules_mock.go @@ -27,6 +27,7 @@ import ( gomock "github.com/golang/mock/gomock" v1beta1 "sigs.k8s.io/cluster-api-provider-azure/api/v1beta1" azure "sigs.k8s.io/cluster-api-provider-azure/azure" + v1beta10 "sigs.k8s.io/cluster-api/api/v1beta1" ) // MockInboundNatScope is a mock of InboundNatScope interface. @@ -52,6 +53,20 @@ func (m *MockInboundNatScope) EXPECT() *MockInboundNatScopeMockRecorder { return m.recorder } +// APIServerLBName mocks base method. +func (m *MockInboundNatScope) APIServerLBName() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "APIServerLBName") + ret0, _ := ret[0].(string) + return ret0 +} + +// APIServerLBName indicates an expected call of APIServerLBName. +func (mr *MockInboundNatScopeMockRecorder) APIServerLBName() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "APIServerLBName", reflect.TypeOf((*MockInboundNatScope)(nil).APIServerLBName)) +} + // AdditionalTags mocks base method. func (m *MockInboundNatScope) AdditionalTags() v1beta1.Tags { m.ctrl.T.Helper() @@ -178,6 +193,18 @@ func (mr *MockInboundNatScopeMockRecorder) ClusterName() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterName", reflect.TypeOf((*MockInboundNatScope)(nil).ClusterName)) } +// DeleteLongRunningOperationState mocks base method. +func (m *MockInboundNatScope) DeleteLongRunningOperationState(arg0, arg1 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DeleteLongRunningOperationState", arg0, arg1) +} + +// DeleteLongRunningOperationState indicates an expected call of DeleteLongRunningOperationState. +func (mr *MockInboundNatScopeMockRecorder) DeleteLongRunningOperationState(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteLongRunningOperationState", reflect.TypeOf((*MockInboundNatScope)(nil).DeleteLongRunningOperationState), arg0, arg1) +} + // FailureDomains mocks base method. func (m *MockInboundNatScope) FailureDomains() []string { m.ctrl.T.Helper() @@ -192,6 +219,20 @@ func (mr *MockInboundNatScopeMockRecorder) FailureDomains() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FailureDomains", reflect.TypeOf((*MockInboundNatScope)(nil).FailureDomains)) } +// GetLongRunningOperationState mocks base method. +func (m *MockInboundNatScope) GetLongRunningOperationState(arg0, arg1 string) *v1beta1.Future { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLongRunningOperationState", arg0, arg1) + ret0, _ := ret[0].(*v1beta1.Future) + return ret0 +} + +// GetLongRunningOperationState indicates an expected call of GetLongRunningOperationState. +func (mr *MockInboundNatScopeMockRecorder) GetLongRunningOperationState(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLongRunningOperationState", reflect.TypeOf((*MockInboundNatScope)(nil).GetLongRunningOperationState), arg0, arg1) +} + // HashKey mocks base method. func (m *MockInboundNatScope) HashKey() string { m.ctrl.T.Helper() @@ -207,17 +248,17 @@ func (mr *MockInboundNatScopeMockRecorder) HashKey() *gomock.Call { } // InboundNatSpecs mocks base method. -func (m *MockInboundNatScope) InboundNatSpecs() []azure.InboundNatSpec { +func (m *MockInboundNatScope) InboundNatSpecs(arg0 map[int32]struct{}) []azure.ResourceSpecGetter { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "InboundNatSpecs") - ret0, _ := ret[0].([]azure.InboundNatSpec) + ret := m.ctrl.Call(m, "InboundNatSpecs", arg0) + ret0, _ := ret[0].([]azure.ResourceSpecGetter) return ret0 } // InboundNatSpecs indicates an expected call of InboundNatSpecs. -func (mr *MockInboundNatScopeMockRecorder) InboundNatSpecs() *gomock.Call { +func (mr *MockInboundNatScopeMockRecorder) InboundNatSpecs(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InboundNatSpecs", reflect.TypeOf((*MockInboundNatScope)(nil).InboundNatSpecs)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InboundNatSpecs", reflect.TypeOf((*MockInboundNatScope)(nil).InboundNatSpecs), arg0) } // Location mocks base method. @@ -248,6 +289,18 @@ func (mr *MockInboundNatScopeMockRecorder) ResourceGroup() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResourceGroup", reflect.TypeOf((*MockInboundNatScope)(nil).ResourceGroup)) } +// SetLongRunningOperationState mocks base method. +func (m *MockInboundNatScope) SetLongRunningOperationState(arg0 *v1beta1.Future) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetLongRunningOperationState", arg0) +} + +// SetLongRunningOperationState indicates an expected call of SetLongRunningOperationState. +func (mr *MockInboundNatScopeMockRecorder) SetLongRunningOperationState(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLongRunningOperationState", reflect.TypeOf((*MockInboundNatScope)(nil).SetLongRunningOperationState), arg0) +} + // SubscriptionID mocks base method. func (m *MockInboundNatScope) SubscriptionID() string { m.ctrl.T.Helper() @@ -275,3 +328,39 @@ func (mr *MockInboundNatScopeMockRecorder) TenantID() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TenantID", reflect.TypeOf((*MockInboundNatScope)(nil).TenantID)) } + +// UpdateDeleteStatus mocks base method. +func (m *MockInboundNatScope) UpdateDeleteStatus(arg0 v1beta10.ConditionType, arg1 string, arg2 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdateDeleteStatus", arg0, arg1, arg2) +} + +// UpdateDeleteStatus indicates an expected call of UpdateDeleteStatus. +func (mr *MockInboundNatScopeMockRecorder) UpdateDeleteStatus(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateDeleteStatus", reflect.TypeOf((*MockInboundNatScope)(nil).UpdateDeleteStatus), arg0, arg1, arg2) +} + +// UpdatePatchStatus mocks base method. +func (m *MockInboundNatScope) UpdatePatchStatus(arg0 v1beta10.ConditionType, arg1 string, arg2 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatePatchStatus", arg0, arg1, arg2) +} + +// UpdatePatchStatus indicates an expected call of UpdatePatchStatus. +func (mr *MockInboundNatScopeMockRecorder) UpdatePatchStatus(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePatchStatus", reflect.TypeOf((*MockInboundNatScope)(nil).UpdatePatchStatus), arg0, arg1, arg2) +} + +// UpdatePutStatus mocks base method. +func (m *MockInboundNatScope) UpdatePutStatus(arg0 v1beta10.ConditionType, arg1 string, arg2 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatePutStatus", arg0, arg1, arg2) +} + +// UpdatePutStatus indicates an expected call of UpdatePutStatus. +func (mr *MockInboundNatScopeMockRecorder) UpdatePutStatus(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePutStatus", reflect.TypeOf((*MockInboundNatScope)(nil).UpdatePutStatus), arg0, arg1, arg2) +} diff --git a/azure/services/inboundnatrules/spec.go b/azure/services/inboundnatrules/spec.go new file mode 100644 index 00000000000..e941643f36f --- /dev/null +++ b/azure/services/inboundnatrules/spec.go @@ -0,0 +1,99 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package inboundnatrules + +import ( + "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2021-02-01/network" + "github.com/Azure/go-autorest/autorest/to" + "github.com/pkg/errors" +) + +// InboundNatSpec defines the specification for an inbound NAT rule. +type InboundNatSpec struct { + Name string + LoadBalancerName string + ResourceGroup string + FrontendIPConfigurationID *string + PortsInUse map[int32]struct{} +} + +// ResourceName returns the name of the inbound NAT rule. +func (s *InboundNatSpec) ResourceName() string { + return s.Name +} + +// ResourceGroupName returns the name of the resource group. +func (s *InboundNatSpec) ResourceGroupName() string { + return s.ResourceGroup +} + +// OwnerResourceName returns the name of the load balancer associated with an inbound NAT rule. +func (s *InboundNatSpec) OwnerResourceName() string { + return s.LoadBalancerName +} + +// Parameters returns the parameters for the inbound NAT rule. +func (s *InboundNatSpec) Parameters(existing interface{}) (parameters interface{}, err error) { + if existing != nil { + if _, ok := existing.(network.InboundNatRule); !ok { + return nil, errors.Errorf("%T is not a network.InboundNatRule", existing) + } + + return nil, nil + } + + if s.FrontendIPConfigurationID == nil { + return nil, errors.Errorf("FrontendIPConfigurationID is not set") + } + + sshFrontendPort, err := getAvailablePort(s.PortsInUse) + if err != nil { + return nil, errors.Wrapf(err, "failed to find available SSH Frontend port for NAT Rule %s in load balancer %s", s.ResourceName(), s.OwnerResourceName()) + } + + rule := network.InboundNatRule{ + Name: to.StringPtr(s.ResourceName()), + InboundNatRulePropertiesFormat: &network.InboundNatRulePropertiesFormat{ + BackendPort: to.Int32Ptr(22), + EnableFloatingIP: to.BoolPtr(false), + IdleTimeoutInMinutes: to.Int32Ptr(4), + FrontendIPConfiguration: &network.SubResource{ + ID: s.FrontendIPConfigurationID, + }, + Protocol: network.TransportProtocolTCP, + FrontendPort: &sshFrontendPort, + }, + } + + return rule, nil +} + +func getAvailablePort(portsInUse map[int32]struct{}) (int32, error) { + // TODO: should we use a different range of ports? + var i int32 = 22 + if _, ok := portsInUse[22]; ok { + for i = 2201; i < 2220; i++ { + if _, ok := portsInUse[i]; !ok { + // Found available port + return i, nil + } + } + return i, errors.Errorf("No available SSH Frontend ports") + } + + return i, nil +} diff --git a/azure/services/inboundnatrules/spec_test.go b/azure/services/inboundnatrules/spec_test.go new file mode 100644 index 00000000000..158a14da4f8 --- /dev/null +++ b/azure/services/inboundnatrules/spec_test.go @@ -0,0 +1,90 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package inboundnatrules + +import ( + "testing" + + . "github.com/onsi/gomega" +) + +func TestGetAvailablePort(t *testing.T) { + testcases := []struct { + name string + portsInput map[int32]struct{} + expectedError string + expectedPortResult int32 + }{ + { + name: "Empty ports", + portsInput: map[int32]struct{}{}, + expectedError: "", + expectedPortResult: 22, + }, + { + name: "22 taken", + portsInput: map[int32]struct{}{ + 22: {}, + }, + expectedError: "", + expectedPortResult: 2201, + }, + { + name: "Existing ports", + portsInput: map[int32]struct{}{ + 22: {}, + 2201: {}, + 2202: {}, + 2204: {}, + }, + expectedError: "", + expectedPortResult: 2203, + }, + { + name: "No ports available", + portsInput: getFullPortsMap(), + expectedError: "No available SSH Frontend ports", + expectedPortResult: 0, + }, + } + for _, tc := range testcases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + t.Parallel() + + res, err := getAvailablePort(tc.portsInput) + if tc.expectedError != "" { + g.Expect(err).To(HaveOccurred()) + g.Expect(err).To(MatchError(tc.expectedError)) + } else { + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(res).To(Equal(tc.expectedPortResult)) + } + }) + } +} + +func getFullPortsMap() map[int32]struct{} { + res := map[int32]struct{}{ + 22: {}, + } + for i := 2201; i < 2220; i++ { + res[int32(i)] = struct{}{} + } + return res +} diff --git a/azure/services/loadbalancers/mock_loadbalancers/loadbalancers_mock.go b/azure/services/loadbalancers/mock_loadbalancers/loadbalancers_mock.go index 10031a3d074..76ce28b3399 100644 --- a/azure/services/loadbalancers/mock_loadbalancers/loadbalancers_mock.go +++ b/azure/services/loadbalancers/mock_loadbalancers/loadbalancers_mock.go @@ -52,6 +52,20 @@ func (m *MockLBScope) EXPECT() *MockLBScopeMockRecorder { return m.recorder } +// APIServerLB mocks base method. +func (m *MockLBScope) APIServerLB() *v1beta1.LoadBalancerSpec { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "APIServerLB") + ret0, _ := ret[0].(*v1beta1.LoadBalancerSpec) + return ret0 +} + +// APIServerLB indicates an expected call of APIServerLB. +func (mr *MockLBScopeMockRecorder) APIServerLB() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "APIServerLB", reflect.TypeOf((*MockLBScope)(nil).APIServerLB)) +} + // APIServerLBName mocks base method. func (m *MockLBScope) APIServerLBName() string { m.ctrl.T.Helper() diff --git a/azure/services/natgateways/mock_natgateways/natgateways_mock.go b/azure/services/natgateways/mock_natgateways/natgateways_mock.go index 53c415e5530..af547f8ebca 100644 --- a/azure/services/natgateways/mock_natgateways/natgateways_mock.go +++ b/azure/services/natgateways/mock_natgateways/natgateways_mock.go @@ -53,6 +53,20 @@ func (m *MockNatGatewayScope) EXPECT() *MockNatGatewayScopeMockRecorder { return m.recorder } +// APIServerLB mocks base method. +func (m *MockNatGatewayScope) APIServerLB() *v1beta1.LoadBalancerSpec { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "APIServerLB") + ret0, _ := ret[0].(*v1beta1.LoadBalancerSpec) + return ret0 +} + +// APIServerLB indicates an expected call of APIServerLB. +func (mr *MockNatGatewayScopeMockRecorder) APIServerLB() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "APIServerLB", reflect.TypeOf((*MockNatGatewayScope)(nil).APIServerLB)) +} + // APIServerLBName mocks base method. func (m *MockNatGatewayScope) APIServerLBName() string { m.ctrl.T.Helper() diff --git a/azure/services/securitygroups/mock_securitygroups/securitygroups_mock.go b/azure/services/securitygroups/mock_securitygroups/securitygroups_mock.go index 795a5f9bb61..d4c78c48579 100644 --- a/azure/services/securitygroups/mock_securitygroups/securitygroups_mock.go +++ b/azure/services/securitygroups/mock_securitygroups/securitygroups_mock.go @@ -52,6 +52,20 @@ func (m *MockNSGScope) EXPECT() *MockNSGScopeMockRecorder { return m.recorder } +// APIServerLB mocks base method. +func (m *MockNSGScope) APIServerLB() *v1beta1.LoadBalancerSpec { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "APIServerLB") + ret0, _ := ret[0].(*v1beta1.LoadBalancerSpec) + return ret0 +} + +// APIServerLB indicates an expected call of APIServerLB. +func (mr *MockNSGScopeMockRecorder) APIServerLB() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "APIServerLB", reflect.TypeOf((*MockNSGScope)(nil).APIServerLB)) +} + // APIServerLBName mocks base method. func (m *MockNSGScope) APIServerLBName() string { m.ctrl.T.Helper() diff --git a/azure/services/subnets/mock_subnets/subnets_mock.go b/azure/services/subnets/mock_subnets/subnets_mock.go index 936dc885d34..9026f657dfc 100644 --- a/azure/services/subnets/mock_subnets/subnets_mock.go +++ b/azure/services/subnets/mock_subnets/subnets_mock.go @@ -52,6 +52,20 @@ func (m *MockSubnetScope) EXPECT() *MockSubnetScopeMockRecorder { return m.recorder } +// APIServerLB mocks base method. +func (m *MockSubnetScope) APIServerLB() *v1beta1.LoadBalancerSpec { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "APIServerLB") + ret0, _ := ret[0].(*v1beta1.LoadBalancerSpec) + return ret0 +} + +// APIServerLB indicates an expected call of APIServerLB. +func (mr *MockSubnetScopeMockRecorder) APIServerLB() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "APIServerLB", reflect.TypeOf((*MockSubnetScope)(nil).APIServerLB)) +} + // APIServerLBName mocks base method. func (m *MockSubnetScope) APIServerLBName() string { m.ctrl.T.Helper() diff --git a/azure/types.go b/azure/types.go index b62a6db74f3..ed07c06238f 100644 --- a/azure/types.go +++ b/azure/types.go @@ -64,12 +64,6 @@ type LBSpec struct { IdleTimeoutInMinutes *int32 } -// InboundNatSpec defines the specification for an inbound NAT rule. -type InboundNatSpec struct { - Name string - LoadBalancerName string -} - // SubnetSpec defines the specification for a Subnet. type SubnetSpec struct { Name string