diff --git a/pkg/actuators/machine/actuator_test.go b/pkg/actuators/machine/actuator_test.go index d0816a4971..2dbf847bd2 100644 --- a/pkg/actuators/machine/actuator_test.go +++ b/pkg/actuators/machine/actuator_test.go @@ -1184,7 +1184,7 @@ func TestGetMachineInstances(t *testing.T) { exists: true, }, { - testcase: "has-status-search-by-id", + testcase: "has-status-search-by-id-running", providerStatus: providerconfigv1.AWSMachineProviderStatus{ InstanceID: aws.String(instanceID), }, @@ -1205,23 +1205,21 @@ func TestGetMachineInstances(t *testing.T) { exists: true, }, { - testcase: "has-status-search-by-id and machine is terminated", + testcase: "has-status-search-by-id-terminated", providerStatus: providerconfigv1.AWSMachineProviderStatus{ InstanceID: aws.String(instanceID), }, awsClientFunc: func(ctrl *gomock.Controller) awsclient.Client { mockAWSClient := mockaws.NewMockClient(ctrl) - request := &ec2.DescribeInstancesInput{ + first := mockAWSClient.EXPECT().DescribeInstances(&ec2.DescribeInstancesInput{ InstanceIds: aws.StringSlice([]string{instanceID}), - } - - mockAWSClient.EXPECT().DescribeInstances(request).Return( + }).Return( stubDescribeInstancesOutput(imageID, instanceID, ec2.InstanceStateNameTerminated), nil, ).Times(1) - request2 := &ec2.DescribeInstancesInput{ + mockAWSClient.EXPECT().DescribeInstances(&ec2.DescribeInstancesInput{ Filters: []*ec2.Filter{ { Name: awsTagFilter("Name"), @@ -1230,16 +1228,13 @@ func TestGetMachineInstances(t *testing.T) { clusterFilter(clusterID), }, - } - - mockAWSClient.EXPECT().DescribeInstances(request2).Return( - stubDescribeInstancesOutput(imageID, instanceID, ec2.InstanceStateNameTerminated), + }).Return( + &ec2.DescribeInstancesOutput{}, nil, - ).Times(1) + ).Times(1).After(first) return mockAWSClient }, - exists: false, }, } diff --git a/pkg/actuators/machine/utils.go b/pkg/actuators/machine/utils.go index 8d1298e2c9..3b095dba12 100644 --- a/pkg/actuators/machine/utils.go +++ b/pkg/actuators/machine/utils.go @@ -18,6 +18,7 @@ package machine import ( "fmt" + "strings" "github.com/golang/glog" @@ -104,20 +105,11 @@ func getExistingInstances(machine *machinev1.Machine, client awsclient.Client) ( } func getExistingInstanceByID(id string, client awsclient.Client) (*ec2.Instance, error) { - instance, err := getInstanceByID(id, client) - if err != nil { - return nil, err - } - if instance.State != nil { - if aws.StringValue(instance.State.Name) == ec2.InstanceStateNameTerminated { - return nil, fmt.Errorf("failed to getExistingInstanceByID for instance-id %s, instance is terminated", id) - } - } - return instance, nil + return getInstanceByID(id, client, existingInstanceStates()) } // getInstanceByID returns the instance with the given ID if it exists. -func getInstanceByID(id string, client awsclient.Client) (*ec2.Instance, error) { +func getInstanceByID(id string, client awsclient.Client, instanceStateFilter []*string) (*ec2.Instance, error) { if id == "" { return nil, fmt.Errorf("instance-id not specified") } @@ -141,7 +133,28 @@ func getInstanceByID(id string, client awsclient.Client) (*ec2.Instance, error) return nil, fmt.Errorf("found %d instances for instance-id %s", len(reservation.Instances), id) } - return reservation.Instances[0], nil + instance := reservation.Instances[0] + + if len(instanceStateFilter) == 0 { + return instance, nil + } + + if instance.State == nil { + return nil, fmt.Errorf("instance %s has nil state", id) + } + + actualState := aws.StringValue(instance.State.Name) + for _, allowedState := range instanceStateFilter { + if aws.StringValue(allowedState) == actualState { + return instance, nil + } + } + + allowedStates := make([]string, 0, len(instanceStateFilter)) + for _, allowedState := range instanceStateFilter { + allowedStates = append(allowedStates, aws.StringValue(allowedState)) + } + return instance, fmt.Errorf("instance %s state %q is not in %s", id, actualState, strings.Join(allowedStates, ", ")) } // getInstances returns all instances that have a tag matching our machine name,