From 5eafd954e2987f013f200efdb174093f51a2928c Mon Sep 17 00:00:00 2001 From: Fabio Bertinatto Date: Sun, 5 May 2019 19:53:57 +0200 Subject: [PATCH 1/2] Add instancey type to metadata --- pkg/cloud/metadata.go | 12 ++++++++++++ pkg/cloud/metadata_test.go | 11 +++++++++++ pkg/driver/mocks/mock_metadata_service.go | 14 ++++++++++++++ tests/e2e/pre_provsioning.go | 13 +++++++++---- 4 files changed, 46 insertions(+), 4 deletions(-) diff --git a/pkg/cloud/metadata.go b/pkg/cloud/metadata.go index e74d32cded..c373308400 100644 --- a/pkg/cloud/metadata.go +++ b/pkg/cloud/metadata.go @@ -30,12 +30,14 @@ type EC2Metadata interface { // MetadataService represents AWS metadata service. type MetadataService interface { GetInstanceID() string + GetInstanceType() string GetRegion() string GetAvailabilityZone() string } type Metadata struct { InstanceID string + InstanceType string Region string AvailabilityZone string } @@ -47,6 +49,11 @@ func (m *Metadata) GetInstanceID() string { return m.InstanceID } +// GetInstanceID returns the instance type. +func (m *Metadata) GetInstanceType() string { + return m.InstanceType +} + // GetRegion returns the region which the instance is in. func (m *Metadata) GetRegion() string { return m.Region @@ -72,6 +79,10 @@ func NewMetadataService(svc EC2Metadata) (MetadataService, error) { return nil, fmt.Errorf("could not get valid EC2 instance ID") } + if len(doc.InstanceType) == 0 { + return nil, fmt.Errorf("could not get valid EC2 instance type") + } + if len(doc.Region) == 0 { return nil, fmt.Errorf("could not get valid EC2 region") } @@ -82,6 +93,7 @@ func NewMetadataService(svc EC2Metadata) (MetadataService, error) { return &Metadata{ InstanceID: doc.InstanceID, + InstanceType: doc.InstanceType, Region: doc.Region, AvailabilityZone: doc.AvailabilityZone, }, nil diff --git a/pkg/cloud/metadata_test.go b/pkg/cloud/metadata_test.go index 2915ddb94d..df333e63c2 100644 --- a/pkg/cloud/metadata_test.go +++ b/pkg/cloud/metadata_test.go @@ -27,6 +27,7 @@ import ( var ( stdInstanceID = "instance-1" + stdInstanceType = "t2.medium" stdRegion = "instance-1" stdAvailabilityZone = "az-1" ) @@ -44,6 +45,7 @@ func TestNewMetadataService(t *testing.T) { isAvailable: true, identityDocument: ec2metadata.EC2InstanceIdentityDocument{ InstanceID: stdInstanceID, + InstanceType: stdInstanceType, Region: stdRegion, AvailabilityZone: stdAvailabilityZone, }, @@ -54,6 +56,7 @@ func TestNewMetadataService(t *testing.T) { isAvailable: false, identityDocument: ec2metadata.EC2InstanceIdentityDocument{ InstanceID: stdInstanceID, + InstanceType: stdInstanceType, Region: stdRegion, AvailabilityZone: stdAvailabilityZone, }, @@ -64,6 +67,7 @@ func TestNewMetadataService(t *testing.T) { isAvailable: true, identityDocument: ec2metadata.EC2InstanceIdentityDocument{ InstanceID: stdInstanceID, + InstanceType: stdInstanceType, Region: stdRegion, AvailabilityZone: stdAvailabilityZone, }, @@ -75,6 +79,7 @@ func TestNewMetadataService(t *testing.T) { isPartial: true, identityDocument: ec2metadata.EC2InstanceIdentityDocument{ InstanceID: "", + InstanceType: stdInstanceType, Region: stdRegion, AvailabilityZone: stdAvailabilityZone, }, @@ -86,6 +91,7 @@ func TestNewMetadataService(t *testing.T) { isPartial: true, identityDocument: ec2metadata.EC2InstanceIdentityDocument{ InstanceID: stdInstanceID, + InstanceType: stdInstanceType, Region: "", AvailabilityZone: stdAvailabilityZone, }, @@ -97,6 +103,7 @@ func TestNewMetadataService(t *testing.T) { isPartial: true, identityDocument: ec2metadata.EC2InstanceIdentityDocument{ InstanceID: stdInstanceID, + InstanceType: stdInstanceType, Region: stdRegion, AvailabilityZone: "", }, @@ -124,6 +131,10 @@ func TestNewMetadataService(t *testing.T) { t.Fatalf("GetInstanceID() failed: expected %v, got %v", tc.identityDocument.InstanceID, m.GetInstanceID()) } + if m.GetInstanceType() != tc.identityDocument.InstanceType { + t.Fatalf("GetInstanceType() failed: expected %v, got %v", tc.identityDocument.InstanceType, m.GetInstanceType()) + } + if m.GetRegion() != tc.identityDocument.Region { t.Fatalf("GetRegion() failed: expected %v, got %v", tc.identityDocument.Region, m.GetRegion()) } diff --git a/pkg/driver/mocks/mock_metadata_service.go b/pkg/driver/mocks/mock_metadata_service.go index dc9ef67c20..dba05fda50 100644 --- a/pkg/driver/mocks/mock_metadata_service.go +++ b/pkg/driver/mocks/mock_metadata_service.go @@ -60,6 +60,20 @@ func (mr *MockMetadataServiceMockRecorder) GetInstanceID() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstanceID", reflect.TypeOf((*MockMetadataService)(nil).GetInstanceID)) } +// GetInstanceType mocks base method +func (m *MockMetadataService) GetInstanceType() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetInstanceType") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetInstanceType indicates an expected call of GetInstanceType +func (mr *MockMetadataServiceMockRecorder) GetInstanceType() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstanceType", reflect.TypeOf((*MockMetadataService)(nil).GetInstanceType)) +} + // GetRegion mocks base method func (m *MockMetadataService) GetRegion() string { m.ctrl.T.Helper() diff --git a/tests/e2e/pre_provsioning.go b/tests/e2e/pre_provsioning.go index d840757100..0bd72a664e 100644 --- a/tests/e2e/pre_provsioning.go +++ b/tests/e2e/pre_provsioning.go @@ -17,16 +17,17 @@ package e2e import ( "context" "fmt" + "math/rand" + "os" + "strings" + awscloud "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud" "github.com/kubernetes-sigs/aws-ebs-csi-driver/tests/e2e/driver" "github.com/kubernetes-sigs/aws-ebs-csi-driver/tests/e2e/testsuites" . "github.com/onsi/ginkgo" - "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" clientset "k8s.io/client-go/kubernetes" "k8s.io/kubernetes/test/e2e/framework" - "math/rand" - "os" - "strings" ebscsidriver "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver" ) @@ -53,6 +54,10 @@ func (s e2eMetdataService) GetInstanceID() string { return "" } +func (s e2eMetdataService) GetInstanceType() string { + return "" +} + func (s e2eMetdataService) GetAvailabilityZone() string { return s.availabilityZone } From 6201e3bb27e702120c2c592657432a5de75a195d Mon Sep 17 00:00:00 2001 From: Fabio Bertinatto Date: Sun, 5 May 2019 21:17:52 +0200 Subject: [PATCH 2/2] Add max number of volumes that can be attached to an instance --- pkg/driver/node.go | 19 +++++++++++ pkg/driver/node_test.go | 74 ++++++++++++++++++++++++++++------------- 2 files changed, 70 insertions(+), 23 deletions(-) diff --git a/pkg/driver/node.go b/pkg/driver/node.go index 921eb8efc4..d6f70dda2f 100644 --- a/pkg/driver/node.go +++ b/pkg/driver/node.go @@ -21,6 +21,7 @@ import ( "fmt" "os" "path/filepath" + "regexp" "strings" csi "github.com/container-storage-interface/spec/lib/go/csi" @@ -44,6 +45,13 @@ const ( // default file system type to be used when it is not provided defaultFsType = FSTypeExt4 + + // defaultMaxEBSVolumes is the maximum number of volumes that an AWS instance can have attached. + // More info at https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/volume_limits.html + defaultMaxEBSVolumes = 39 + + // defaultMaxEBSNitroVolumes is the limit of volumes for some smaller instances, like c5 and m5. + defaultMaxEBSNitroVolumes = 25 ) var ( @@ -316,6 +324,7 @@ func (d *nodeService) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoReque return &csi.NodeGetInfoResponse{ NodeId: m.GetInstanceID(), + MaxVolumesPerNode: d.getVolumesLimit(), AccessibleTopology: topology, }, nil } @@ -464,6 +473,16 @@ func findNvmeVolume(findName string) (device string, err error) { return resolved, nil } +// getVolumesLimit returns the limit of volumes that the node supports +func (d *nodeService) getVolumesLimit() int64 { + ebsNitroInstanceTypeRegex := "^[cmr]5.*|t3|z1d" + instanceType := d.metadata.GetInstanceType() + if ok, _ := regexp.MatchString(ebsNitroInstanceTypeRegex, instanceType); ok { + return defaultMaxEBSNitroVolumes + } + return defaultMaxEBSVolumes +} + func newSafeMounter() *mount.SafeFormatAndMount { return &mount.SafeFormatAndMount{ Interface: mount.New(""), diff --git a/pkg/driver/node_test.go b/pkg/driver/node_test.go index 73732d21dd..31f94a181a 100644 --- a/pkg/driver/node_test.go +++ b/pkg/driver/node_test.go @@ -755,34 +755,62 @@ func TestNodeGetCapabilities(t *testing.T) { } func TestNodeGetInfo(t *testing.T) { - mockCtl := gomock.NewController(t) - defer mockCtl.Finish() + testCases := []struct { + name string + instanceID string + instanceType string + availabilityZone string + expMaxVolumes int64 + }{ + { + name: "success normal", + instanceID: "i-123456789abcdef01", + instanceType: "t2.medium", + availabilityZone: "us-west-2b", + expMaxVolumes: 39, + }, + { + name: "success normal with NVMe", + instanceID: "i-123456789abcdef01", + instanceType: "m5d.large", + availabilityZone: "us-west-2b", + expMaxVolumes: 25, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() - mockMetadata := mocks.NewMockMetadataService(mockCtl) - mockMetadata.EXPECT().GetInstanceID().Return(expInstanceId) - mockMetadata.EXPECT().GetAvailabilityZone().Return(expZone).Times(2) + mockMetadata := mocks.NewMockMetadataService(mockCtl) + mockMetadata.EXPECT().GetInstanceID().Return(tc.instanceID) + mockMetadata.EXPECT().GetInstanceType().Return(tc.instanceType) + mockMetadata.EXPECT().GetAvailabilityZone().Return(tc.availabilityZone) - req := &csi.NodeGetInfoRequest{} + awsDriver := newTestNodeService(mockMetadata, NewFakeMounter()) - awsDriver := newTestNodeService(mockMetadata, NewFakeMounter()) + resp, err := awsDriver.NodeGetInfo(context.TODO(), &csi.NodeGetInfoRequest{}) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + t.Fatalf("Expected nil error, got %d message %s", srvErr.Code(), srvErr.Message()) + } - expResp := &csi.NodeGetInfoResponse{ - NodeId: expInstanceId, - AccessibleTopology: &csi.Topology{ - Segments: map[string]string{TopologyKey: mockMetadata.GetAvailabilityZone()}, - }, - } + if resp.GetNodeId() != tc.instanceID { + t.Fatalf("Expected node ID %q, got %q", tc.instanceID, resp.GetNodeId()) + } - resp, err := awsDriver.NodeGetInfo(context.TODO(), req) - if err != nil { - srvErr, ok := status.FromError(err) - if !ok { - t.Fatalf("Could not get error status code from error: %v", srvErr) - } - t.Fatalf("Expected nil error, got %d message %s", srvErr.Code(), srvErr.Message()) - } - if !reflect.DeepEqual(expResp, resp) { - t.Fatalf("Expected response {%+v}, got {%+v}", expResp, resp) + at := resp.GetAccessibleTopology() + if at.Segments[TopologyKey] != tc.availabilityZone { + t.Fatalf("Expected topology %q, got %q", tc.availabilityZone, at.Segments[TopologyKey]) + } + + if resp.GetMaxVolumesPerNode() != tc.expMaxVolumes { + t.Fatalf("Expected %d max volumes per node, got %d", tc.expMaxVolumes, resp.GetMaxVolumesPerNode()) + } + }) } }