diff --git a/pkg/awsutils/awsutils.go b/pkg/awsutils/awsutils.go index 8039fc86c6..eec130830c 100644 --- a/pkg/awsutils/awsutils.go +++ b/pkg/awsutils/awsutils.go @@ -22,7 +22,6 @@ import ( "os" "regexp" "strings" - "sync" "time" "github.com/aws/amazon-vpc-cni-k8s/pkg/utils/logger" @@ -137,7 +136,7 @@ type APIs interface { DeallocIPAddresses(eniID string, ips []string) error // GetVPCIPv4CIDRs returns VPC's CIDRs from instance metadata - GetVPCIPv4CIDRs() []string + GetVPCIPv4CIDRs() ([]string, error) // GetLocalIPv4 returns the primary IP address on the primary ENI interface GetLocalIPv4() net.IP @@ -158,13 +157,10 @@ type APIs interface { // EC2InstanceMetadataCache caches instance metadata type EC2InstanceMetadataCache struct { // metadata info - securityGroups StringSet subnetID string localIPv4 net.IP instanceID string instanceType string - vpcIPv4CIDR string - vpcIPv4CIDRs StringSet primaryENI string primaryENImac string availabilityZone string @@ -217,34 +213,6 @@ func prometheusRegister() { } } -//StringSet is a set of strings -type StringSet struct { - sync.RWMutex - data sets.String -} - -func (ss *StringSet) SortedList() []string { - ss.RLock() - defer ss.RUnlock() - // sets.String.List() returns a sorted list - return ss.data.List() -} - -func (ss *StringSet) Set(items []string) { - ss.Lock() - defer ss.Unlock() - ss.data = sets.NewString(items...) -} - -func (ss *StringSet) Difference(other *StringSet) *StringSet { - ss.RLock() - other.RLock() - defer ss.RUnlock() - defer other.RUnlock() - //example: s1 = {a1, a2, a3} s2 = {a1, a2, a4, a5} s1.Difference(s2) = {a3} s2.Difference(s1) = {a4, a5} - return &StringSet{data: ss.data.Difference(other.data)} -} - type instrumentedIMDS struct { EC2MetadataIface } @@ -368,83 +336,6 @@ func (cache *EC2InstanceMetadataCache) initWithEC2Metadata(ctx context.Context) } log.Debugf("Found subnet-id: %s ", cache.subnetID) - // retrieve security groups - err = cache.refreshSGIDs(mac) - if err != nil { - return err - } - - // retrieve VPC IPv4 CIDR blocks - err = cache.refreshVPCIPv4CIDRs(mac) - if err != nil { - return err - } - - // Refresh security groups and VPC CIDR blocks in the background - // Ignoring errors since we will retry in 30s - go wait.Forever(func() { _ = cache.refreshSGIDs(mac) }, 30*time.Second) - go wait.Forever(func() { _ = cache.refreshVPCIPv4CIDRs(mac) }, 30*time.Second) - - // We use the ctx here for testing, since we spawn go-routines above which will run forever. - select { - case <-ctx.Done(): - return nil - default: - } - return nil -} - -// refreshSGIDs retrieves security groups -func (cache *EC2InstanceMetadataCache) refreshSGIDs(mac string) error { - ctx := context.TODO() - - sgIDs, err := cache.imds.GetSecurityGroupIDs(ctx, mac) - if err != nil { - return err - } - - newSGs := StringSet{} - newSGs.Set(sgIDs) - addedSGs := newSGs.Difference(&cache.securityGroups) - deletedSGs := cache.securityGroups.Difference(&newSGs) - - for _, sg := range addedSGs.SortedList() { - log.Infof("Found %s, added to ipamd cache", sg) - } - for _, sg := range deletedSGs.SortedList() { - log.Infof("Removed %s from ipamd cache", sg) - } - cache.securityGroups.Set(sgIDs) - return nil -} - -// refreshVPCIPv4CIDRs retrieves VPC IPv4 CIDR blocks -func (cache *EC2InstanceMetadataCache) refreshVPCIPv4CIDRs(mac string) error { - ctx := context.TODO() - - ipnets, err := cache.imds.GetVPCIPv4CIDRBlocks(ctx, mac) - if err != nil { - return err - } - - // TODO: keep as net.IPNet and remove this round-trip to/from string - vpcIPv4CIDRs := make([]string, len(ipnets)) - for i, ipnet := range ipnets { - vpcIPv4CIDRs[i] = ipnet.String() - } - - newVpcIPv4CIDRs := StringSet{} - newVpcIPv4CIDRs.Set(vpcIPv4CIDRs) - addedVpcIPv4CIDRs := newVpcIPv4CIDRs.Difference(&cache.vpcIPv4CIDRs) - deletedVpcIPv4CIDRs := cache.vpcIPv4CIDRs.Difference(&newVpcIPv4CIDRs) - - for _, vpcIPv4CIDR := range addedVpcIPv4CIDRs.SortedList() { - log.Infof("Found %s, added to ipamd cache", vpcIPv4CIDR) - } - for _, vpcIPv4CIDR := range deletedVpcIPv4CIDRs.SortedList() { - log.Infof("Removed %s from ipamd cache", vpcIPv4CIDR) - } - cache.vpcIPv4CIDRs.Set(vpcIPv4CIDRs) return nil } @@ -629,11 +520,11 @@ func (cache *EC2InstanceMetadataCache) attachENI(eniID string) (string, error) { // return ENI id, error func (cache *EC2InstanceMetadataCache) createENI(useCustomCfg bool, sg []*string, subnet string) (string, error) { + ctx := context.TODO() + eniDescription := eniDescriptionPrefix + cache.instanceID input := &ec2.CreateNetworkInterfaceInput{ Description: aws.String(eniDescription), - Groups: aws.StringSlice(cache.securityGroups.SortedList()), - SubnetId: aws.String(cache.subnetID), } if useCustomCfg { @@ -642,12 +533,14 @@ func (cache *EC2InstanceMetadataCache) createENI(useCustomCfg bool, sg []*string input.SubnetId = aws.String(subnet) } else { log.Info("Using same config as the primary interface for the new ENI") + sgIDs, err := cache.imds.GetSecurityGroupIDs(ctx, cache.primaryENImac) + if err != nil { + return "", err + } + input.Groups = aws.StringSlice(sgIDs) + input.SubnetId = aws.String(cache.subnetID) } - var sgs []string - for i := range input.Groups { - sgs = append(sgs, *input.Groups[i]) - } - log.Infof("Creating ENI with security groups: %v in subnet: %s", sgs, *input.SubnetId) + log.Infof("Creating ENI with security groups: %v in subnet: %s", aws.StringValueSlice(input.Groups), *input.SubnetId) start := time.Now() result, err := cache.ec2SVC.CreateNetworkInterfaceWithContext(context.Background(), input, userAgent) awsAPILatency.WithLabelValues("CreateNetworkInterface", fmt.Sprint(err != nil), awsReqStatus(err)).Observe(msSince(start)) @@ -1233,8 +1126,21 @@ func (cache *EC2InstanceMetadataCache) getFilteredListOfNetworkInterfaces() ([]* } // GetVPCIPv4CIDRs returns VPC CIDRs -func (cache *EC2InstanceMetadataCache) GetVPCIPv4CIDRs() []string { - return cache.vpcIPv4CIDRs.SortedList() +func (cache *EC2InstanceMetadataCache) GetVPCIPv4CIDRs() ([]string, error) { + ctx := context.TODO() + + ipnets, err := cache.imds.GetVPCIPv4CIDRBlocks(ctx, cache.primaryENImac) + if err != nil { + return nil, err + } + + // TODO: keep as net.IPNet and remove this round-trip to/from string + asStrs := make([]string, len(ipnets)) + for i, ipnet := range ipnets { + asStrs[i] = ipnet.String() + } + + return asStrs, nil } // GetLocalIPv4 returns the primary IP address on the primary interface diff --git a/pkg/awsutils/awsutils_test.go b/pkg/awsutils/awsutils_test.go index 4e8f9dc174..f072ceb79c 100644 --- a/pkg/awsutils/awsutils_test.go +++ b/pkg/awsutils/awsutils_test.go @@ -105,9 +105,7 @@ func TestInitWithEC2metadata(t *testing.T) { assert.Equal(t, ins.instanceID, instanceID) assert.Equal(t, ins.primaryENImac, primaryMAC) assert.Equal(t, ins.primaryENI, primaryeniID) - assert.Equal(t, len(ins.securityGroups.SortedList()), 2) assert.Equal(t, subnetID, ins.subnetID) - assert.Equal(t, len(ins.vpcIPv4CIDRs.SortedList()), 2) } } @@ -390,6 +388,7 @@ func TestAllocENI(t *testing.T) { ins := &EC2InstanceMetadataCache{ ec2SVC: mockEC2, imds: TypedIMDS{mockMetadata}, + primaryENImac: primaryMAC, } _, err := ins.AllocENI(false, nil, "") assert.NoError(t, err) @@ -423,6 +422,7 @@ func TestAllocENINoFreeDevice(t *testing.T) { ins := &EC2InstanceMetadataCache{ ec2SVC: mockEC2, imds: TypedIMDS{mockMetadata}, + primaryENImac: primaryMAC, } _, err := ins.AllocENI(false, nil, "") assert.Error(t, err) @@ -458,6 +458,7 @@ func TestAllocENIMaxReached(t *testing.T) { ins := &EC2InstanceMetadataCache{ ec2SVC: mockEC2, imds: TypedIMDS{mockMetadata}, + primaryENImac: primaryMAC, } _, err := ins.AllocENI(false, nil, "") assert.Error(t, err) diff --git a/pkg/awsutils/mocks/awsutils_mocks.go b/pkg/awsutils/mocks/awsutils_mocks.go index 4cf4531e3b..091e00e44b 100644 --- a/pkg/awsutils/mocks/awsutils_mocks.go +++ b/pkg/awsutils/mocks/awsutils_mocks.go @@ -240,11 +240,12 @@ func (mr *MockAPIsMockRecorder) GetPrimaryENImac() *gomock.Call { } // GetVPCIPv4CIDRs mocks base method -func (m *MockAPIs) GetVPCIPv4CIDRs() []string { +func (m *MockAPIs) GetVPCIPv4CIDRs() ([]string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetVPCIPv4CIDRs") ret0, _ := ret[0].([]string) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // GetVPCIPv4CIDRs indicates an expected call of GetVPCIPv4CIDRs diff --git a/pkg/ipamd/ipamd.go b/pkg/ipamd/ipamd.go index bc370ef67a..bb515d1c7f 100644 --- a/pkg/ipamd/ipamd.go +++ b/pkg/ipamd/ipamd.go @@ -17,7 +17,6 @@ import ( "fmt" "net" "os" - "reflect" "strconv" "strings" "sync" @@ -33,6 +32,7 @@ import ( "github.com/aws/aws-sdk-go/service/ec2" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" + "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/kubernetes" ) @@ -338,7 +338,10 @@ func (c *IPAMContext) nodeInit() error { return err } - vpcCIDRs := c.awsClient.GetVPCIPv4CIDRs() + vpcCIDRs, err := c.awsClient.GetVPCIPv4CIDRs() + if err != nil { + return err + } primaryIP := c.awsClient.GetLocalIPv4() err = c.networkClient.SetupHostNetwork(vpcCIDRs, c.awsClient.GetPrimaryENImac(), &primaryIP) if err != nil { @@ -388,6 +391,10 @@ func (c *IPAMContext) nodeInit() error { if err = c.configureIPRulesForPods(vpcCIDRs); err != nil { return err } + // Spawning updateCIDRsRulesOnChange go-routine + go wait.Forever(func() { + vpcCIDRs = c.updateCIDRsRulesOnChange(vpcCIDRs) + }, 30*time.Second) // For a new node, attach IPs increasedPool, err := c.tryAssignIPs() @@ -397,8 +404,6 @@ func (c *IPAMContext) nodeInit() error { return err } - // Spawning updateCIDRsRulesOnChange go-routine - go wait.Forever(func() { vpcCIDRs = c.updateCIDRsRulesOnChange(vpcCIDRs) }, 30*time.Second) return nil } @@ -423,10 +428,16 @@ func (c *IPAMContext) configureIPRulesForPods(pbVPCcidrs []string) error { return nil } -func (c *IPAMContext) updateCIDRsRulesOnChange(oldVPCCidrs []string) []string { - newVPCCIDRs := c.awsClient.GetVPCIPv4CIDRs() +func (c *IPAMContext) updateCIDRsRulesOnChange(oldVPCCIDRs []string) []string { + newVPCCIDRs, err := c.awsClient.GetVPCIPv4CIDRs() + if err != nil { + log.Warnf("skipping periodic update to VPC CIDRs due to error: %v", err) + return oldVPCCIDRs + } - if len(oldVPCCidrs) != len(newVPCCIDRs) || !reflect.DeepEqual(oldVPCCidrs, newVPCCIDRs) { + old := sets.NewString(oldVPCCIDRs...) + new := sets.NewString(newVPCCIDRs...) + if !old.Equal(new) { _ = c.configureIPRulesForPods(newVPCCIDRs) } return newVPCCIDRs diff --git a/pkg/ipamd/ipamd_test.go b/pkg/ipamd/ipamd_test.go index 1a07b3abc7..6fac34ef40 100644 --- a/pkg/ipamd/ipamd_test.go +++ b/pkg/ipamd/ipamd_test.go @@ -102,7 +102,7 @@ func TestNodeInit(t *testing.T) { m.awsutils.EXPECT().GetIPv4sFromEC2(eni2.ENIID).AnyTimes().Return(eni2.IPv4Addresses, nil) primaryIP := net.ParseIP(ipaddr01) - m.awsutils.EXPECT().GetVPCIPv4CIDRs().AnyTimes().Return(cidrs) + m.awsutils.EXPECT().GetVPCIPv4CIDRs().AnyTimes().Return(cidrs, nil) m.awsutils.EXPECT().GetPrimaryENImac().Return("") m.network.EXPECT().SetupHostNetwork(cidrs, "", &primaryIP).Return(nil) diff --git a/pkg/ipamd/rpc_handler.go b/pkg/ipamd/rpc_handler.go index 7d8a2cee47..81c51fcf8e 100644 --- a/pkg/ipamd/rpc_handler.go +++ b/pkg/ipamd/rpc_handler.go @@ -46,14 +46,10 @@ func (s *server) AddNetwork(ctx context.Context, in *rpc.AddNetworkRequest) (*rp log.Infof("Received AddNetwork for NS %s, Sandbox %s, ifname %s", in.Netns, in.ContainerID, in.IfName) - ipamKey := datastore.IPAMKey{ - ContainerID: in.ContainerID, - IfName: in.IfName, - NetworkName: in.NetworkName, + pbVPCcidrs, err := s.ipamContext.awsClient.GetVPCIPv4CIDRs() + if err != nil { + return nil, err } - addr, deviceNumber, err := s.ipamContext.dataStore.AssignPodIPv4Address(ipamKey) - - pbVPCcidrs := s.ipamContext.awsClient.GetVPCIPv4CIDRs() for _, cidr := range pbVPCcidrs { log.Debugf("VPC CIDR %s", cidr) } @@ -66,6 +62,13 @@ func (s *server) AddNetwork(ctx context.Context, in *rpc.AddNetworkRequest) (*rp } } + ipamKey := datastore.IPAMKey{ + ContainerID: in.ContainerID, + IfName: in.IfName, + NetworkName: in.NetworkName, + } + addr, deviceNumber, err := s.ipamContext.dataStore.AssignPodIPv4Address(ipamKey) + resp := rpc.AddNetworkReply{ Success: err == nil, IPv4Addr: addr, diff --git a/pkg/ipamd/rpc_handler_test.go b/pkg/ipamd/rpc_handler_test.go index fd04dfa32b..5da4af52c4 100644 --- a/pkg/ipamd/rpc_handler_test.go +++ b/pkg/ipamd/rpc_handler_test.go @@ -68,18 +68,19 @@ func TestServer_AddNetwork(t *testing.T) { }, } for _, tc := range testCases { - m.awsutils.EXPECT().GetVPCIPv4CIDRs().Return(tc.vpcCIDRs) + m.awsutils.EXPECT().GetVPCIPv4CIDRs().Return(tc.vpcCIDRs, nil) m.network.EXPECT().UseExternalSNAT().Return(tc.useExternalSNAT) if !tc.useExternalSNAT { m.network.EXPECT().GetExcludeSNATCIDRs().Return(tc.snatExclusionCIDRs) } addNetworkReply, err := rpcServer.AddNetwork(context.TODO(), addNetworkRequest) - assert.NoError(t, err, tc.name) + if assert.NoError(t, err, tc.name) { - assert.Equal(t, tc.useExternalSNAT, addNetworkReply.UseExternalSNAT, tc.name) + assert.Equal(t, tc.useExternalSNAT, addNetworkReply.UseExternalSNAT, tc.name) - expectedCIDRs := append([]string{vpcCIDR}, tc.snatExclusionCIDRs...) - assert.Equal(t, expectedCIDRs, addNetworkReply.VPCcidrs, tc.name) + expectedCIDRs := append([]string{vpcCIDR}, tc.snatExclusionCIDRs...) + assert.Equal(t, expectedCIDRs, addNetworkReply.VPCcidrs, tc.name) + } } }