From cbc93742a6561e3bb60df09351d1f07d4af854a5 Mon Sep 17 00:00:00 2001
From: Anish Ramasekar
Date: Tue, 2 Jul 2019 15:36:02 -0700
Subject: [PATCH] nmi retry and sync retry loop update retry

---
Review notes (commentary only; not part of the commit message or the diff):

* cmd/mic/main.go: adds the syncRetryInterval flag and hands it to
  mic.NewMICClient. The flag is registered with flag.DurationVar (default
  3600*time.Second) instead of being kept as a string: time.ParseDuration
  rejects a unit-less value such as "3600", and multiplying an already parsed
  time.Duration by time.Second scales it a second time. A non-positive value
  is rejected before it can reach time.NewTicker, which panics on d <= 0.
* pkg/crd/crd.go: ListPodIds only returns identities whose Status.Status is
  "Assigned".
* pkg/mic/mic.go: Client stores syncRetryInterval and Sync also wakes up on a
  ticker with that period, so the sync body runs without a new event.
* pkg/mic/mic_test.go: adds TestSyncRetryLoop. The first status assertion now
  prints IdentityAssigned (the value it actually checks) and a failed
  ListAssignedIDs call stops the test before the result is dereferenced.
* pkg/nmi/server/server.go: hostHandler and msiHandler look identities up
  through listPodIDsWithRetry and answer 404 instead of 403 when none is found.
  NOTE(review): listPodIDsWithRetry performs maxAttempts+1 lookups and reports
  "Error: <nil>" when the lookups succeed but no identity matches - confirm
  whether that wording is intended.
* The "index" blob ids below are the ones recorded in the original commit.

 cmd/mic/main.go          | 20 ++++++---
 pkg/crd/crd.go           |  2 +-
 pkg/mic/mic.go           | 40 ++++++++++--------
 pkg/mic/mic_test.go      | 88 +++++++++++++++++++++++++++++++++++++---
 pkg/nmi/server/server.go | 55 +++++++++++++++++++++----
 5 files changed, 168 insertions(+), 37 deletions(-)

diff --git a/cmd/mic/main.go b/cmd/mic/main.go
index 2ff283cb1..36bfc8a14 100644
--- a/cmd/mic/main.go
+++ b/cmd/mic/main.go
@@ -3,6 +3,7 @@ package main
 import (
 	"flag"
 	"os"
+	"time"
 
 	"github.com/Azure/aad-pod-identity/pkg/mic"
 	"github.com/Azure/aad-pod-identity/version"
@@ -12,10 +13,11 @@ import (
 )
 
 var (
-	kubeconfig      string
-	cloudconfig     string
-	forceNamespaced bool
-	versionInfo     bool
+	kubeconfig        string
+	cloudconfig       string
+	forceNamespaced   bool
+	versionInfo       bool
+	syncRetryInterval time.Duration
 )
 
 func main() {
@@ -24,6 +26,8 @@ func main() {
 	flag.StringVar(&cloudconfig, "cloudconfig", "", "Path to cloud config e.g. Azure.json file")
 	flag.BoolVar(&forceNamespaced, "forceNamespaced", false, "Forces namespaced identities, binding, and assignment")
 	flag.BoolVar(&versionInfo, "version", false, "Prints the version information")
+	flag.DurationVar(&syncRetryInterval, "syncRetryInterval", 3600*time.Second, "The interval (a Go duration such as 3600s or 1h) at which sync loop should periodically check for errors and reconcile.")
+
 	flag.Parse()
 	if versionInfo {
 		version.PrintVersionAndExit()
@@ -43,7 +47,13 @@ func main() {
 	}
 
 	forceNamespaced = forceNamespaced || "true" == os.Getenv("FORCENAMESPACED")
-	micClient, err := mic.NewMICClient(cloudconfig, config, forceNamespaced)
+
+	// time.NewTicker panics on a non-positive interval, so reject it up front.
+	if syncRetryInterval <= 0 {
+		glog.Fatalf("syncRetryInterval must be greater than zero, got %v", syncRetryInterval)
+	}
+
+	micClient, err := mic.NewMICClient(cloudconfig, config, forceNamespaced, syncRetryInterval)
 	if err != nil {
 		glog.Fatalf("Could not get the MIC client: %+v", err)
 	}
diff --git a/pkg/crd/crd.go b/pkg/crd/crd.go
index 7a792ddf7..0598285c7 100644
--- a/pkg/crd/crd.go
+++ b/pkg/crd/crd.go
@@ -267,7 +267,7 @@ func (c *Client) ListPodIds(podns, podname string) (*[]aadpodid.AzureIdentity, e
 
 	var matchedIds []aadpodid.AzureIdentity
 	for _, v := range azAssignedIDList.(*aadpodid.AzureAssignedIdentityList).Items {
-		if v.Spec.Pod == podname && v.Spec.PodNamespace == podns {
+		if v.Spec.Pod == podname && v.Spec.PodNamespace == podns && v.Status.Status == "Assigned" {
 			matchedIds = append(matchedIds, *v.Spec.AzureIdentityRef)
 		}
 	}
diff --git a/pkg/mic/mic.go b/pkg/mic/mic.go
index 3825d7ab1..3b1844089 100644
--- a/pkg/mic/mic.go
+++ b/pkg/mic/mic.go
@@ -43,13 +43,14 @@ type NodeGetter interface {
 // Client has the required pointers to talk to the api server
 // and interact with the CRD related datastructure.
 type Client struct {
-	CRDClient     crd.ClientInt
-	CloudClient   cloudprovider.ClientInt
-	PodClient     pod.ClientInt
-	EventRecorder record.EventRecorder
-	EventChannel  chan aadpodid.EventType
-	NodeClient    NodeGetter
-	IsNamespaced  bool
+	CRDClient         crd.ClientInt
+	CloudClient       cloudprovider.ClientInt
+	PodClient         pod.ClientInt
+	EventRecorder     record.EventRecorder
+	EventChannel      chan aadpodid.EventType
+	NodeClient        NodeGetter
+	IsNamespaced      bool
+	syncRetryInterval time.Duration
 
 	syncing int32 // protect against conucrrent sync's
 }
@@ -68,7 +69,7 @@ type trackUserAssignedMSIIds struct {
 }
 
 // NewMICClient returnes new mic client
-func NewMICClient(cloudconfig string, config *rest.Config, isNamespaced bool) (*Client, error) {
+func NewMICClient(cloudconfig string, config *rest.Config, isNamespaced bool, syncRetryInterval time.Duration) (*Client, error) {
 	glog.Infof("Starting to create the pod identity client. Version: %v. Build date: %v", version.MICVersion, version.BuildDate)
 
 	clientSet := kubernetes.NewForConfigOrDie(config)
@@ -96,13 +97,14 @@ func NewMICClient(cloudconfig string, config *rest.Config, isNamespaced bool) (*
 	recorder := eventBroadcaster.NewRecorder(scheme.Scheme, corev1.EventSource{Component: aadpodid.CRDGroup})
 
 	return &Client{
-		CRDClient:     crdClient,
-		CloudClient:   cloudClient,
-		PodClient:     podClient,
-		EventRecorder: recorder,
-		EventChannel:  eventCh,
-		NodeClient:    &NodeClient{informer.Core().V1().Nodes()},
-		IsNamespaced:  isNamespaced,
+		CRDClient:         crdClient,
+		CloudClient:       cloudClient,
+		PodClient:         podClient,
+		EventRecorder:     recorder,
+		EventChannel:      eventCh,
+		NodeClient:        &NodeClient{informer.Core().V1().Nodes()},
+		IsNamespaced:      isNamespaced,
+		syncRetryInterval: syncRetryInterval,
 	}, nil
 }
 
@@ -152,6 +154,9 @@ func (c *Client) Sync(exit <-chan struct{}) {
 	}
 	defer c.setStopped()
 
+	ticker := time.NewTicker(c.syncRetryInterval)
+	defer ticker.Stop()
+
 	glog.Info("Sync thread started.")
 	var event aadpodid.EventType
 	for {
@@ -159,13 +164,16 @@ func (c *Client) Sync(exit <-chan struct{}) {
 		case <-exit:
 			return
 		case event = <-c.EventChannel:
+			glog.V(6).Infof("Received event: %v", event)
+		case <-ticker.C:
+			glog.V(6).Infof("Running sync retry loop")
 		}
 
 		stats.Init()
 		// This is the only place where the AzureAssignedIdentity creation is initiated.
 		begin := time.Now()
 		workDone := false
-		glog.V(6).Infof("Received event: %v", event)
+
 		// List all pods in all namespaces
 		systemTime := time.Now()
 		listPods, err := c.PodClient.GetPods()
diff --git a/pkg/mic/mic_test.go b/pkg/mic/mic_test.go
index bf3ae0423..43b0c65d5 100644
--- a/pkg/mic/mic_test.go
+++ b/pkg/mic/mic_test.go
@@ -556,12 +556,13 @@ func (c *TestEventRecorder) AnnotatedEventf(object runtime.Object, annotations m
 func NewMICTestClient(eventCh chan aadpodid.EventType, cpClient *TestCloudClient, crdClient *TestCrdClient, podClient *TestPodClient, nodeClient *TestNodeClient, eventRecorder *TestEventRecorder) *TestMICClient {
 
 	realMICClient := &Client{
-		CloudClient:   cpClient,
-		CRDClient:     crdClient,
-		EventRecorder: eventRecorder,
-		PodClient:     podClient,
-		EventChannel:  eventCh,
-		NodeClient:    nodeClient,
+		CloudClient:       cpClient,
+		CRDClient:         crdClient,
+		EventRecorder:     eventRecorder,
+		PodClient:         podClient,
+		EventChannel:      eventCh,
+		NodeClient:        nodeClient,
+		syncRetryInterval: 120 * time.Second,
 	}
 
 	return &TestMICClient{
@@ -1079,6 +1080,81 @@ func TestMICStateFlow(t *testing.T) {
 	}
 }
 
+func TestSyncRetryLoop(t *testing.T) {
+	eventCh := make(chan aadpodid.EventType, 100)
+	cloudClient := NewTestCloudClient(config.AzureConfig{})
+	crdClient := NewTestCrdClient(nil)
+	podClient := NewTestPodClient()
+	nodeClient := NewTestNodeClient()
+	var evtRecorder TestEventRecorder
+	evtRecorder.lastEvent = new(LastEvent)
+	evtRecorder.eventChannel = make(chan bool, 100)
+
+	micClient := NewMICTestClient(eventCh, cloudClient, crdClient, podClient, nodeClient, &evtRecorder)
+	micClient.syncRetryInterval = 10 * time.Second
+
+	// Add a pod, identity and binding.
+	crdClient.CreateID("test-id1", aadpodid.UserAssignedMSI, "test-user-msi-resourceid", "test-user-msi-clientid", nil, "", "", "")
+	crdClient.CreateBinding("testbinding1", "test-id1", "test-select1")
+
+	nodeClient.AddNode("test-node1")
+	podClient.AddPod("test-pod1", "default", "test-node1", "test-select1")
+
+	eventCh <- aadpodid.PodCreated
+	defer micClient.testRunSync()(t)
+
+	if !evtRecorder.WaitForEvents(1) {
+		t.Fatalf("Timeout waiting for mic sync cycles")
+	}
+	listAssignedIDs, err := crdClient.ListAssignedIDs()
+	if err != nil {
+		glog.Error(err)
+		t.Fatalf("list assigned failed")
+	}
+	if !(len(*listAssignedIDs) == 1) {
+		t.Fatalf("expected assigned identities len: %d, got: %d", 1, len(*listAssignedIDs))
+	}
+	if !((*listAssignedIDs)[0].Status.Status == IdentityAssigned) {
+		t.Fatalf("expected status to be %s, got: %s", IdentityAssigned, (*listAssignedIDs)[0].Status.Status)
+	}
+
+	// delete the pod, simulate failure in cloud calls on trying to un-assign identity from node
+	podClient.DeletePod("test-pod1", "default")
+	cloudClient.SetError(errors.New("error removing identity from node"))
+	cloudClient.testVMClient.identity = &compute.VirtualMachineIdentity{IdentityIds: &[]string{"test-user-msi-resourceid"}}
+
+	eventCh <- aadpodid.PodDeleted
+	if !evtRecorder.WaitForEvents(1) {
+		t.Fatalf("Timeout waiting for mic sync cycles")
+	}
+
+	listAssignedIDs, err = crdClient.ListAssignedIDs()
+	if err != nil {
+		glog.Error(err)
+		t.Fatalf("list assigned failed")
+	}
+	if !(len(*listAssignedIDs) == 1) {
+		t.Fatalf("expected assigned identities len: %d, got: %d", 1, len(*listAssignedIDs))
+	}
+	if !((*listAssignedIDs)[0].Status.Status == IdentityAssigned) {
+		t.Fatalf("expected status to be %s, got: %s", IdentityAssigned, (*listAssignedIDs)[0].Status.Status)
+	}
+	cloudClient.UnSetError()
+
+	if !evtRecorder.WaitForEvents(1) {
+		t.Fatalf("Timeout waiting for mic sync retry cycle")
+	}
+
+	listAssignedIDs, err = crdClient.ListAssignedIDs()
+	if err != nil {
+		glog.Error(err)
+		t.Fatalf("list assigned failed")
+	}
+	if !(len(*listAssignedIDs) == 0) {
+		t.Fatalf("expected assigned identities len: %d, got: %d", 0, len(*listAssignedIDs))
+	}
+}
+
 func TestSyncExit(t *testing.T) {
 	eventCh := make(chan aadpodid.EventType)
 	cloudClient := NewTestCloudClient(config.AzureConfig{VMType: "vmss"})
diff --git a/pkg/nmi/server/server.go b/pkg/nmi/server/server.go
index 001a016dd..e7e71cbcc 100644
--- a/pkg/nmi/server/server.go
+++ b/pkg/nmi/server/server.go
@@ -27,6 +27,8 @@ import (
 const (
 	iptableUpdateTimeIntervalInSeconds = 60
 	localhost                          = "127.0.0.1"
+	listPodIDsRetryAttempts            = 10
+	listPodIDsRetryIntervalInSeconds   = 6
 )
 
 // Server encapsulates all of the parameters necessary for starting up
@@ -153,6 +155,8 @@ func (fn appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 
 func (s *Server) hostHandler(logger *log.Entry, w http.ResponseWriter, r *http.Request) {
 	hostIP := parseRemoteAddr(r.RemoteAddr)
+	rqClientID, rqResource := parseRequestClientIDAndResource(r)
+
 	if hostIP != localhost {
 		msg := "request remote address is not from a host"
 		logger.Error(msg)
@@ -167,15 +171,15 @@ func (s *Server) hostHandler(logger *log.Entry, w http.ResponseWriter, r *http.R
 		return
 	}
 
-	podIDs, err := s.KubeClient.ListPodIds(podns, podname)
-	if err != nil || len(*podIDs) == 0 {
+	podIDs, err := listPodIDsWithRetry(s.KubeClient, logger, podns, podname, rqClientID, listPodIDsRetryAttempts)
+	if err != nil {
 		msg := fmt.Sprintf("no AzureAssignedIdentity found for pod:%s/%s", podns, podname)
 		logger.Errorf("%s, %+v", msg, err)
-		http.Error(w, msg, http.StatusForbidden)
+		http.Error(w, msg, http.StatusNotFound)
 		return
 	}
 
-	// filter out if we are in namesoaced mode
+	// filter out if we are in namespaced mode
 	filterPodIdentities := []aadpodid.AzureIdentity{}
 	for _, val := range *(podIDs) {
 		if s.IsNamespaced || aadpodid.IsNamespacedIdentity(&val) {
@@ -193,7 +197,6 @@ func (s *Server) hostHandler(logger *log.Entry, w http.ResponseWriter, r *http.R
 		}
 	}
 	podIDs = &filterPodIdentities
-	rqClientID, rqResource := parseRequestClientIDAndResource(r)
 	token, clientID, err := getTokenForMatchingID(s.KubeClient, logger, rqClientID, rqResource, podIDs)
 	if err != nil {
 		logger.Errorf("failed to get service principal token for pod:%s/%s, %+v", podns, podname, err)
@@ -220,6 +223,8 @@ func (s *Server) hostHandler(logger *log.Entry, w http.ResponseWriter, r *http.R
 // configured id.
 func (s *Server) msiHandler(logger *log.Entry, w http.ResponseWriter, r *http.Request) {
 	podIP := parseRemoteAddr(r.RemoteAddr)
+	rqClientID, rqResource := parseRequestClientIDAndResource(r)
+
 	if podIP == "" {
 		msg := "request remote address is empty"
 		logger.Error(msg)
@@ -232,14 +237,15 @@ func (s *Server) msiHandler(logger *log.Entry, w http.ResponseWriter, r *http.Re
 		http.Error(w, err.Error(), http.StatusInternalServerError)
 		return
 	}
-	podIDs, err := s.KubeClient.ListPodIds(podns, podname)
-	if err != nil || len(*podIDs) == 0 {
+
+	podIDs, err := listPodIDsWithRetry(s.KubeClient, logger, podns, podname, rqClientID, listPodIDsRetryAttempts)
+	if err != nil {
 		msg := fmt.Sprintf("no AzureAssignedIdentity found for pod:%s/%s", podns, podname)
 		logger.Errorf("%s, %+v", msg, err)
-		http.Error(w, msg, http.StatusForbidden)
+		http.Error(w, msg, http.StatusNotFound)
 		return
 	}
-	rqClientID, rqResource := parseRequestClientIDAndResource(r)
+
 	token, _, err := getTokenForMatchingID(s.KubeClient, logger, rqClientID, rqResource, podIDs)
 	if err != nil {
 		logger.Errorf("failed to get service principal token for pod:%s/%s, %+v", podns, podname, err)
@@ -380,3 +386,34 @@ func handleTermination() {
 	log.Infof("Exiting with %v", exitCode)
 	os.Exit(exitCode)
 }
+
+func listPodIDsWithRetry(kubeClient k8s.Client, logger *log.Entry, podns, podname, rqClientID string, maxAttempts int) (*[]aadpodid.AzureIdentity, error) {
+	attempt := 0
+	var err error
+	var podIDs *[]aadpodid.AzureIdentity
+
+	for {
+		podIDs, err = kubeClient.ListPodIds(podns, podname)
+		if err == nil && len(*podIDs) != 0 {
+			if len(rqClientID) == 0 {
+				return podIDs, nil
+			}
+			// if client id exists in request, we need to ensure the identity with this client
+			// exists and is in Assigned state
+			for _, podID := range *podIDs {
+				if strings.EqualFold(rqClientID, podID.Spec.ClientID) {
+					return podIDs, nil
+				}
+			}
+		}
+
+		if attempt >= maxAttempts {
+			break
+		}
+
+		attempt++
+		logger.Warningf("failed to get assigned ids for pod:%s/%s, retrying attempt: %d", podns, podname, attempt)
+		time.Sleep(listPodIDsRetryIntervalInSeconds * time.Second)
+	}
+	return nil, fmt.Errorf("getting assigned identities for pod %s/%s failed after %d attempts. Error: %v", podns, podname, maxAttempts, err)
+}