From d05af768460f01b0d71f250a4200d9a2a8599cca Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Mon, 11 May 2020 18:24:42 -0700 Subject: [PATCH 1/6] serviceregistration: refactor service registration logic to run later --- command/server.go | 23 ++- .../consul/consul_service_registration.go | 13 +- .../consul_service_registration_test.go | 30 +-- .../kubernetes/client/client.go | 37 ++-- .../kubernetes/client/client_test.go | 2 +- .../kubernetes/client/cmd/kubeclient/main.go | 84 +++++--- .../kubernetes/retry_handler.go | 179 +++++++++++++++--- .../kubernetes/retry_handler_test.go | 85 +++++---- .../kubernetes/service_registration.go | 104 ++-------- .../kubernetes/service_registration_test.go | 4 +- serviceregistration/service_registration.go | 4 +- vault/core.go | 2 +- 12 files changed, 329 insertions(+), 238 deletions(-) diff --git a/command/server.go b/command/server.go index 8494ec5e8c10..8d0fbcc3036a 100644 --- a/command/server.go +++ b/command/server.go @@ -966,9 +966,6 @@ func (c *ServerCommand) Run(args []string) int { return 1 } - // Instantiate the wait group - c.WaitGroup = &sync.WaitGroup{} - // Initialize the Service Discovery, if there is one var configSR sr.ServiceRegistration if config.ServiceRegistration != nil { @@ -990,15 +987,11 @@ func (c *ServerCommand) Run(args []string) int { IsActive: false, IsPerformanceStandby: false, } - configSR, err = sdFactory(config.ServiceRegistration.Config, namedSDLogger, state, config.Storage.RedirectAddr) + configSR, err = sdFactory(config.ServiceRegistration.Config, namedSDLogger, state) if err != nil { c.UI.Error(fmt.Sprintf("Error initializing service_registration of type %s: %s", config.ServiceRegistration.Type, err)) return 1 } - if err := configSR.Run(c.ShutdownCh, c.WaitGroup); err != nil { - c.UI.Error(fmt.Sprintf("Error running service_registration of type %s: %s", config.ServiceRegistration.Type, err)) - return 1 - } } infoKeys := make([]string, 0, 10) @@ -1311,7 +1304,7 @@ CLUSTER_SYNTHESIS_COMPLETE: // If ServiceRegistration is configured, then the backend must support HA isBackendHA := coreConfig.HAPhysical != nil && coreConfig.HAPhysical.HAEnabled() - if !c.flagDev && (coreConfig.ServiceRegistration != nil) && !isBackendHA { + if !c.flagDev && (coreConfig.GetServiceRegistration() != nil) && !isBackendHA { c.UI.Output("service_registration is configured, but storage does not support HA") return 1 } @@ -1578,6 +1571,18 @@ CLUSTER_SYNTHESIS_COMPLETE: } // Perform initialization of HTTP server after the verifyOnly check.
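Editor's note: the server.go hunks above and below split service registration into two phases: the backend is constructed while the configuration is parsed, but it is only started, via Run, after storage and core setup and the verifyOnly check, once the final redirect address is known. A condensed orientation sketch of the flow this patch establishes (names such as sdFactory, configSR, and coreConfig come from command/server.go; error handling and the intervening setup are elided, so this is not independently compilable):

```go
// Phase 1 — config time: construct the backend. Note the factory no
// longer receives a redirect address.
configSR, err := sdFactory(config.ServiceRegistration.Config, namedSDLogger, state)
if err != nil {
	// report and abort startup (the real code prints a UI error and returns 1)
}

// ... storage setup, core creation, and the verifyOnly check happen here ...

// Phase 2 — after verification: start the backend with the redirect
// address that is only now known. This mirrors the hunk added below,
// which nil-checks via coreConfig.GetServiceRegistration().
c.WaitGroup = &sync.WaitGroup{}
if sd := coreConfig.GetServiceRegistration(); sd != nil {
	if err := configSR.Run(c.ShutdownCh, c.WaitGroup, coreConfig.RedirectAddr); err != nil {
		// report and abort startup
	}
}
```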
+ + // Instantiate the wait group + c.WaitGroup = &sync.WaitGroup{} + + // If service discovery is available, run service discovery + if sd := coreConfig.GetServiceRegistration(); sd != nil { + if err := configSR.Run(c.ShutdownCh, c.WaitGroup, coreConfig.RedirectAddr); err != nil { + c.UI.Error(fmt.Sprintf("Error running service_registration of type %s: %s", config.ServiceRegistration.Type, err)) + return 1 + } + } + // If we're in Dev mode, then initialize the core if c.flagDev && !c.flagDevSkipInit { init, err := c.enableDev(core, coreConfig) diff --git a/serviceregistration/consul/consul_service_registration.go b/serviceregistration/consul/consul_service_registration.go index ccf010d2f935..249916224615 100644 --- a/serviceregistration/consul/consul_service_registration.go +++ b/serviceregistration/consul/consul_service_registration.go @@ -68,7 +68,6 @@ type serviceRegistration struct { serviceAddress *string disableRegistration bool checkTimeout time.Duration - redirectAddr string notifyActiveCh chan struct{} notifySealedCh chan struct{} @@ -78,8 +77,7 @@ type serviceRegistration struct { } // NewConsulServiceRegistration constructs a Consul-based ServiceRegistration. -func NewServiceRegistration(conf map[string]string, logger log.Logger, state sr.State, redirectAddr string) (sr.ServiceRegistration, error) { - +func NewServiceRegistration(conf map[string]string, logger log.Logger, state sr.State) (sr.ServiceRegistration, error) { // Allow admins to disable consul integration disableReg, ok := conf["disable_registration"] var disableRegistration bool @@ -208,7 +206,6 @@ func NewServiceRegistration(conf map[string]string, logger log.Logger, state sr. serviceAddress: serviceAddr, checkTimeout: checkTimeout, disableRegistration: disableRegistration, - redirectAddr: redirectAddr, notifyActiveCh: make(chan struct{}), notifySealedCh: make(chan struct{}), @@ -221,9 +218,9 @@ func NewServiceRegistration(conf map[string]string, logger log.Logger, state sr. return c, nil } -func (c *serviceRegistration) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup) error { +func (c *serviceRegistration) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup, redirectAddr string) error { go func() { - if err := c.runServiceRegistration(wait, shutdownCh, c.redirectAddr); err != nil { + if err := c.runServiceRegistration(wait, shutdownCh, redirectAddr); err != nil { if c.logger.IsError() { c.logger.Error(fmt.Sprintf("error running service registration: %s", err)) } @@ -290,12 +287,12 @@ func (c *serviceRegistration) runServiceRegistration(waitGroup *sync.WaitGroup, // 'server' command will wait for the below goroutine to complete waitGroup.Add(1) - go c.runEventDemuxer(waitGroup, shutdownCh, redirectAddr) + go c.runEventDemuxer(waitGroup, shutdownCh) return nil } -func (c *serviceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh <-chan struct{}, redirectAddr string) { +func (c *serviceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh <-chan struct{}) { // This defer statement should be executed last. So push it first. 
defer waitGroup.Done() diff --git a/serviceregistration/consul/consul_service_registration_test.go b/serviceregistration/consul/consul_service_registration_test.go index 468699041c18..811a21905021 100644 --- a/serviceregistration/consul/consul_service_registration_test.go +++ b/serviceregistration/consul/consul_service_registration_test.go @@ -32,11 +32,11 @@ func testConsulServiceRegistrationConfig(t *testing.T, conf *consulConf) *servic defer func() { close(shutdownCh) }() - be, err := NewServiceRegistration(*conf, logger, sr.State{}, "") + be, err := NewServiceRegistration(*conf, logger, sr.State{}) if err != nil { t.Fatalf("Expected Consul to initialize: %v", err) } - if err := be.Run(shutdownCh, &sync.WaitGroup{}); err != nil { + if err := be.Run(shutdownCh, &sync.WaitGroup{}, ""); err != nil { t.Fatal(err) } @@ -69,8 +69,10 @@ func TestConsul_ServiceRegistration(t *testing.T) { waitForServices := func(t *testing.T, expected map[string][]string) map[string][]string { t.Helper() // Wait for up to 10 seconds + var services map[string][]string + var err error for i := 0; i < 10; i++ { - services, _, err := client.Catalog().Services(nil) + services, _, err = client.Catalog().Services(nil) if err != nil { t.Fatal(err) } @@ -79,7 +81,7 @@ func TestConsul_ServiceRegistration(t *testing.T) { } time.Sleep(time.Second) } - t.Fatalf("Catalog Services never reached expected state %v", expected) + t.Fatalf("Catalog Services never reached: got: %v, expected state: %v", services, expected) return nil } @@ -94,11 +96,11 @@ func TestConsul_ServiceRegistration(t *testing.T) { sd, err := NewServiceRegistration(map[string]string{ "address": addr, "token": token, - }, logger, sr.State{}, redirectAddr) + }, logger, sr.State{}) if err != nil { t.Fatal(err) } - if err := sd.Run(shutdownCh, &sync.WaitGroup{}); err != nil { + if err := sd.Run(shutdownCh, &sync.WaitGroup{}, redirectAddr); err != nil { t.Fatal(err) } @@ -167,11 +169,11 @@ func TestConsul_ServiceTags(t *testing.T) { close(shutdownCh) }() - be, err := NewServiceRegistration(consulConfig, logger, sr.State{}, "") + be, err := NewServiceRegistration(consulConfig, logger, sr.State{}) if err != nil { t.Fatal(err) } - if err := be.Run(shutdownCh, &sync.WaitGroup{}); err != nil { + if err := be.Run(shutdownCh, &sync.WaitGroup{}, ""); err != nil { t.Fatal(err) } @@ -226,11 +228,11 @@ func TestConsul_ServiceAddress(t *testing.T) { shutdownCh := make(chan struct{}) logger := logging.NewVaultLogger(log.Debug) - be, err := NewServiceRegistration(test.consulConfig, logger, sr.State{}, "") + be, err := NewServiceRegistration(test.consulConfig, logger, sr.State{}) if err != nil { t.Fatalf("expected Consul to initialize: %v", err) } - if err := be.Run(shutdownCh, &sync.WaitGroup{}); err != nil { + if err := be.Run(shutdownCh, &sync.WaitGroup{}, ""); err != nil { t.Fatal(err) } @@ -355,7 +357,7 @@ func TestConsul_newConsulServiceRegistration(t *testing.T) { shutdownCh := make(chan struct{}) logger := logging.NewVaultLogger(log.Debug) - be, err := NewServiceRegistration(test.consulConfig, logger, sr.State{}, "") + be, err := NewServiceRegistration(test.consulConfig, logger, sr.State{}) if test.fail { if err == nil { t.Fatalf(`Expected config "%s" to fail`, test.name) @@ -365,7 +367,7 @@ func TestConsul_newConsulServiceRegistration(t *testing.T) { } else if !test.fail && err != nil { t.Fatalf("Expected config %s to not fail: %v", test.name, err) } - if err := be.Run(shutdownCh, &sync.WaitGroup{}); err != nil { + if err := be.Run(shutdownCh, &sync.WaitGroup{}, ""); 
err != nil { t.Fatal(err) } @@ -559,7 +561,7 @@ func TestConsul_serviceID(t *testing.T) { shutdownCh := make(chan struct{}) be, err := NewServiceRegistration(consulConf{ "service": test.serviceName, - }, logger, sr.State{}, "") + }, logger, sr.State{}) if !test.valid { if err == nil { t.Fatalf("expected an error initializing for name %q", test.serviceName) @@ -569,7 +571,7 @@ func TestConsul_serviceID(t *testing.T) { if test.valid && err != nil { t.Fatalf("expected Consul to initialize: %v", err) } - if err := be.Run(shutdownCh, &sync.WaitGroup{}); err != nil { + if err := be.Run(shutdownCh, &sync.WaitGroup{}, ""); err != nil { t.Fatal(err) } diff --git a/serviceregistration/kubernetes/client/client.go b/serviceregistration/kubernetes/client/client.go index afbd0b616e4c..934d3bad908c 100644 --- a/serviceregistration/kubernetes/client/client.go +++ b/serviceregistration/kubernetes/client/client.go @@ -29,9 +29,20 @@ var ( ErrNotInCluster = errors.New("unable to load in-cluster configuration, KUBERNETES_SERVICE_HOST and KUBERNETES_SERVICE_PORT must be defined") ) +// Client is a minimal Kubernetes client. We rolled our own because the existing +// Kubernetes client-go library available externally has a high number of dependencies +// and we thought it wasn't worth it for only two API calls. If at some point they break +// the client into smaller modules, or if we add quite a few methods to this client, it may +// be worthwhile to revisit that decision. +type Client struct { + logger hclog.Logger + config *Config + stopCh chan struct{} +} + // New instantiates a Client. The stopCh is used for exiting retry loops // when closed. -func New(logger hclog.Logger, stopCh <-chan struct{}) (*Client, error) { +func New(logger hclog.Logger) (*Client, error) { config, err := inClusterConfig() if err != nil { return nil, err @@ -39,19 +50,12 @@ func New(logger hclog.Logger, stopCh <-chan struct{}) (*Client, error) { return &Client{ logger: logger, config: config, - stopCh: stopCh, + stopCh: make(chan struct{}), }, nil } -// Client is a minimal Kubernetes client. We rolled our own because the existing -// Kubernetes client-go library available externally has a high number of dependencies -// and we thought it wasn't worth it for only two API calls. If at some point they break -// the client into smaller modules, or if we add quite a few methods to this client, it may -// be worthwhile to revisit that decision. -type Client struct { - logger hclog.Logger - config *Config - stopCh <-chan struct{} +func (c *Client) Shutdown() { + close(c.stopCh) } // GetPod gets a pod from the Kubernetes API. @@ -132,10 +136,13 @@ func (c *Client) do(req *http.Request, ptrToReturnObj interface{}) error { // a stop from our stopChan. This allows us to exit from our retry // loop during a shutdown, rather than hanging. 
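Editor's note: the do() change continuing below replaces a watcher goroutine that could only ever exit via stopCh with one that also exits when the request's context ends. The same pattern in a self-contained form, using plain net/http rather than the go-retryablehttp wrapper the real client uses (do and stopCh here are stand-ins for Client.do and Client.stopCh; the deferred cancel, not visible in the hunk itself, is what guarantees the watcher exits as soon as the call returns):

```go
import (
	"context"
	"net/http"
)

// do issues req, aborting it early if stopCh closes first.
func do(req *http.Request, stopCh <-chan struct{}) (*http.Response, error) {
	ctx, cancel := context.WithCancel(req.Context())
	defer cancel() // unblocks the watcher once this call returns

	go func() {
		select {
		case <-ctx.Done():
			// Request finished on its own; the watcher exits without leaking.
		case <-stopCh:
			// Shutdown requested; abort the in-flight (or retrying) request.
			cancel()
		}
	}()

	return http.DefaultClient.Do(req.WithContext(ctx))
}
```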
ctx, cancelFunc := context.WithCancel(context.Background()) - go func(stopCh <-chan struct{}) { - <-stopCh - cancelFunc() - }(c.stopCh) + go func() { + select { + case <-ctx.Done(): + case <-c.stopCh: + cancelFunc() + } + }() retryableReq.WithContext(ctx) retryableReq.Header.Set("Authorization", "Bearer "+c.config.BearerToken) diff --git a/serviceregistration/kubernetes/client/client_test.go b/serviceregistration/kubernetes/client/client_test.go index 76d87499f090..d87cd75569d4 100644 --- a/serviceregistration/kubernetes/client/client_test.go +++ b/serviceregistration/kubernetes/client/client_test.go @@ -22,7 +22,7 @@ func TestClient(t *testing.T) { t.Fatal(err) } - client, err := New(hclog.Default(), make(chan struct{})) + client, err := New(hclog.Default()) if err != nil { t.Fatal(err) } diff --git a/serviceregistration/kubernetes/client/cmd/kubeclient/main.go b/serviceregistration/kubernetes/client/cmd/kubeclient/main.go index 536b63f22ee4..bca8394b06a5 100644 --- a/serviceregistration/kubernetes/client/cmd/kubeclient/main.go +++ b/serviceregistration/kubernetes/client/cmd/kubeclient/main.go @@ -21,7 +21,10 @@ import ( "encoding/json" "flag" "fmt" + "os" + "os/signal" "strings" + "syscall" "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/serviceregistration/kubernetes/client" @@ -42,39 +45,64 @@ func init() { func main() { flag.Parse() - c, err := client.New(hclog.Default(), make(chan struct{})) + c, err := client.New(hclog.Default()) if err != nil { panic(err) } - switch callToMake { - case "get-pod": - pod, err := c.GetPod(namespace, podName) - if err != nil { - panic(err) - } - b, _ := json.Marshal(pod) - fmt.Printf("pod: %s\n", b) - return - case "patch-pod": - patchPairs := strings.Split(patchesToAdd, ",") - var patches []*client.Patch - for _, patchPair := range patchPairs { - fields := strings.Split(patchPair, ":") - if len(fields) != 2 { - panic(fmt.Errorf("unable to split %s from selectors provided of %s", fields, patchesToAdd)) + reqCh := make(chan struct{}) + shutdownCh := makeShutdownCh() + + go func() { + defer close(reqCh) + + switch callToMake { + case "get-pod": + pod, err := c.GetPod(namespace, podName) + if err != nil { + panic(err) } - patches = append(patches, &client.Patch{ - Operation: client.Replace, - Path: fields[0], - Value: fields[1], - }) - } - if err := c.PatchPod(namespace, podName, patches...); err != nil { - panic(err) + b, _ := json.Marshal(pod) + fmt.Printf("pod: %s\n", b) + return + case "patch-pod": + patchPairs := strings.Split(patchesToAdd, ",") + var patches []*client.Patch + for _, patchPair := range patchPairs { + fields := strings.Split(patchPair, ":") + if len(fields) != 2 { + panic(fmt.Errorf("unable to split %s from selectors provided of %s", fields, patchesToAdd)) + } + patches = append(patches, &client.Patch{ + Operation: client.Replace, + Path: fields[0], + Value: fields[1], + }) + } + if err := c.PatchPod(namespace, podName, patches...); err != nil { + panic(err) + } + return + default: + panic(fmt.Errorf(`unsupported call provided: %q`, callToMake)) } - return - default: - panic(fmt.Errorf(`unsupported call provided: %q`, callToMake)) + }() + + select { + case <-shutdownCh: + fmt.Println("Interrupt received, exiting...") + case <-reqCh: } } + +func makeShutdownCh() chan struct{} { + resultCh := make(chan struct{}) + + shutdownCh := make(chan os.Signal, 4) + signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM) + go func() { + <-shutdownCh + close(resultCh) + }() + return resultCh +} diff --git 
a/serviceregistration/kubernetes/retry_handler.go b/serviceregistration/kubernetes/retry_handler.go index 67381bc734d5..21d71ac8faf7 100644 --- a/serviceregistration/kubernetes/retry_handler.go +++ b/serviceregistration/kubernetes/retry_handler.go @@ -2,11 +2,14 @@ package kubernetes import ( "fmt" + "strconv" "sync" "time" "github.com/hashicorp/go-hclog" + sr "github.com/hashicorp/vault/serviceregistration" "github.com/hashicorp/vault/serviceregistration/kubernetes/client" + "github.com/oklog/run" ) // How often to retry sending a state update if it fails. @@ -22,58 +25,88 @@ type retryHandler struct { // To synchronize setInitialState and patchesToRetry. lock sync.Mutex - // setInitialState will be nil if this has been done successfully. - // It must be done before any patches are retried. - setInitialState func() error + // initialStateSet determines whether an initial state has been set + // successfully or whether a state already exists. + initialStateSet bool + + // State stores an initial state to be set + initialState sr.State // The map holds the path to the label being updated. It will only either // not hold a particular label, or hold _the last_ state we were aware of. // These should only be updated after initial state has been set. patchesToRetry map[string]*client.Patch -} -func (r *retryHandler) SetInitialState(setInitialState func() error) { - r.lock.Lock() - defer r.lock.Unlock() - if err := setInitialState(); err != nil { - if r.logger.IsWarn() { - r.logger.Warn(fmt.Sprintf("unable to set initial state due to %s, will retry", err.Error())) - } - r.setInitialState = setInitialState - } + // client is the Client to use when making API calls against kubernetes + client *client.Client } // Run must be called for retries to be started. -func (r *retryHandler) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup, c *client.Client) { +func (r *retryHandler) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup) { + wait.Add(1) + + r.setInitialState(shutdownCh) + // Run this in a go func so this call doesn't block. go func() { // Make sure Vault will give us time to finish up here. - wait.Add(1) defer wait.Done() - retry := time.NewTicker(retryFreq) - defer retry.Stop() - for { - select { - case <-shutdownCh: - return - case <-retry.C: - r.retry(c) + var g run.Group + + // This run group watches for the shutdownCh + g.Add(func() error { + <-shutdownCh + return nil + }, func(error) {}) + + checkUpdateStateStop := make(chan struct{}) + g.Add(func() error { + r.periodicUpdateState(checkUpdateStateStop) + return nil + }, func(error) { + close(checkUpdateStateStop) + r.client.Shutdown() + }) + + g.Run() + }() +} + +func (r *retryHandler) setInitialState(shutdownCh <-chan struct{}) error { + r.lock.Lock() + defer r.lock.Unlock() + + doneCh := make(chan struct{}) + + go func() { + if err := r.setInitialStateInternal(); err != nil { + if r.logger.IsWarn() { + r.logger.Warn(fmt.Sprintf("unable to set initial state due to %s, will retry", err.Error())) } } + close(doneCh) }() + + // Wait until the state is set or shutdown happens + select { + case <-doneCh: + case <-shutdownCh: + } + + return nil } // Notify adds a patch to be retried until it's either completed without // error, or no longer needed. 
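Editor's note: the Run rewrite above adopts github.com/oklog/run: each goroutine becomes an execute/interrupt actor pair, and once any execute returns, the group invokes every interrupt and waits for the rest. A minimal standalone illustration of the pattern (shutdownCh, retryFreq, and the periodic-work placeholder are stand-ins for the real handler's pieces; note the shutdown actor pairs its execute with an interrupt that can unwind it, the refinement patch 6 applies at the end of this series):

```go
var g run.Group

// Actor 1: wait for an external shutdown signal. Its interrupt closes a
// private channel so the group can also unwind this actor when some
// other actor exits first.
watchStop := make(chan struct{})
g.Add(func() error {
	select {
	case <-shutdownCh:
	case <-watchStop:
	}
	return nil
}, func(error) {
	close(watchStop)
})

// Actor 2: periodic work until interrupted.
tickStop := make(chan struct{})
g.Add(func() error {
	ticker := time.NewTicker(retryFreq)
	defer ticker.Stop()
	for {
		select {
		case <-tickStop:
			return nil
		case <-ticker.C:
			// e.g. retry any pending label patches
		}
	}
}, func(error) {
	close(tickStop)
})

// Run blocks until the first execute returns, then calls every interrupt
// and waits for the remaining actors before returning the first error.
if err := g.Run(); err != nil {
	// log the error
}
```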
-func (r *retryHandler) Notify(c *client.Client, patch *client.Patch) { +func (r *retryHandler) Notify(patch *client.Patch) { r.lock.Lock() defer r.lock.Unlock() // Initial state must be set first, or subsequent notifications we've // received could get smashed by a late-arriving initial state. // We will store this to retry it when appropriate. - if r.setInitialState != nil { + if !r.initialStateSet { if r.logger.IsWarn() { r.logger.Warn(fmt.Sprintf("cannot notify of present state for %s because initial state is unset", patch.Path)) } @@ -82,7 +115,7 @@ func (r *retryHandler) Notify(c *client.Client, patch *client.Patch) { } // Initial state has been sent, so it's OK to attempt a patch immediately. - if err := c.PatchPod(r.namespace, r.podName, patch); err != nil { + if err := r.client.PatchPod(r.namespace, r.podName, patch); err != nil { if r.logger.IsWarn() { r.logger.Warn(fmt.Sprintf("unable to update state for %s due to %s, will retry", patch.Path, err.Error())) } @@ -90,14 +123,98 @@ func (r *retryHandler) Notify(c *client.Client, patch *client.Patch) { } } -func (r *retryHandler) retry(c *client.Client) { +// setInitialState sets the initial state remotely. This should be called with +// the lock held. +func (r *retryHandler) setInitialStateInternal() error { + // Verify that the pod exists and our configuration looks good. + pod, err := r.client.GetPod(r.namespace, r.podName) + if err != nil { + return err + } + + // Now to initially label our pod. + if pod.Metadata == nil { + // This should never happen IRL, just being defensive. + return fmt.Errorf("no pod metadata on %+v", pod) + } + if pod.Metadata.Labels == nil { + // Notify the labels field, and the labels as part of that one call. + // The reason we must take a different approach to adding them is discussed here: + // https://stackoverflow.com/questions/57480205/error-while-applying-json-patch-to-kubernetes-custom-resource + if err := r.client.PatchPod(r.namespace, r.podName, &client.Patch{ + Operation: client.Add, + Path: "/metadata/labels", + Value: map[string]string{ + labelVaultVersion: r.initialState.VaultVersion, + labelActive: strconv.FormatBool(r.initialState.IsActive), + labelSealed: strconv.FormatBool(r.initialState.IsSealed), + labelPerfStandby: strconv.FormatBool(r.initialState.IsPerformanceStandby), + labelInitialized: strconv.FormatBool(r.initialState.IsInitialized), + }, + }); err != nil { + return err + } + } else { + // Create the labels through a patch to each individual field. 
+ patches := []*client.Patch{ + { + Operation: client.Replace, + Path: pathToLabels + labelVaultVersion, + Value: r.initialState.VaultVersion, + }, + { + Operation: client.Replace, + Path: pathToLabels + labelActive, + Value: strconv.FormatBool(r.initialState.IsActive), + }, + { + Operation: client.Replace, + Path: pathToLabels + labelSealed, + Value: strconv.FormatBool(r.initialState.IsSealed), + }, + { + Operation: client.Replace, + Path: pathToLabels + labelPerfStandby, + Value: strconv.FormatBool(r.initialState.IsPerformanceStandby), + }, + { + Operation: client.Replace, + Path: pathToLabels + labelInitialized, + Value: strconv.FormatBool(r.initialState.IsInitialized), + }, + } + if err := r.client.PatchPod(r.namespace, r.podName, patches...); err != nil { + return err + } + } + return nil +} + +func (r *retryHandler) periodicUpdateState(stopCh chan struct{}) { + retry := time.NewTicker(retryFreq) + defer retry.Stop() + + for { + // Call updateState immediately so we don't wait for the first tick + // when setting the initial state + r.updateState() + + select { + case <-stopCh: + return + case <-retry.C: + } + } +} + +func (r *retryHandler) updateState() { r.lock.Lock() defer r.lock.Unlock() // Initial state must be set first, or subsequent notifications we've // received could get smashed by a late-arriving initial state. - if r.setInitialState != nil { - if err := r.setInitialState(); err != nil { + if !r.initialStateSet { + if err := r.setInitialStateInternal(); err != nil { if r.logger.IsWarn() { r.logger.Warn(fmt.Sprintf("unable to set initial state due to %s, will retry", err.Error())) } @@ -106,7 +223,7 @@ return } // On success, we set it to nil and allow the logic to continue. - r.setInitialState = nil + r.initialStateSet = true } if len(r.patchesToRetry) == 0 { @@ -121,7 +238,7 @@ i++ } - if err := c.PatchPod(r.namespace, r.podName, patches...); err != nil { + if err := r.client.PatchPod(r.namespace, r.podName, patches...); err != nil { if r.logger.IsWarn() { r.logger.Warn(fmt.Sprintf("unable to update state for due to %s, will retry", err.Error())) } diff --git a/serviceregistration/kubernetes/retry_handler_test.go b/serviceregistration/kubernetes/retry_handler_test.go index 441cbaf14935..e2be809d0a09 100644 --- a/serviceregistration/kubernetes/retry_handler_test.go +++ b/serviceregistration/kubernetes/retry_handler_test.go @@ -33,13 +33,8 @@ func TestRetryHandlerSimple(t *testing.T) { logger := hclog.NewNullLogger() shutdownCh := make(chan struct{}) wait := &sync.WaitGroup{} - testPatch := &client.Patch{ - Operation: client.Add, - Path: "patch-path", - Value: "true", - } - c, err := client.New(logger, shutdownCh) + c, err := client.New(logger) if err != nil { t.Fatal(err) } @@ -49,17 +44,30 @@ func TestRetryHandlerSimple(t *testing.T) { namespace: kubetest.ExpectedNamespace, podName: kubetest.ExpectedPodName, patchesToRetry: make(map[string]*client.Patch), + client: c, + initialState: sr.State{}, } - r.Run(shutdownCh, wait, c) + r.Run(shutdownCh, wait) - if testState.NumPatches() != 0 { - t.Fatal("expected no current patches") + // Initial number of patches after Run sets the initial state + initStatePatches := testState.NumPatches() + if initStatePatches == 0 { + t.Fatalf("expected number of state patches after setting initial state to be non-zero") } - r.Notify(c, testPatch) + + // Send a new patch + testPatch := &client.Patch{ + Operation: client.Add, + Path: "patch-path",
+ Value: "true", + } + r.Notify(testPatch) + // Wait ample until the next try should have occurred. <-time.NewTimer(retryFreq * 2).C - if testState.NumPatches() != 1 { - t.Fatal("expected 1 patch") + + if testState.NumPatches() != initStatePatches+1 { + t.Fatalf("expected %d patches, got: %d", initStatePatches+1, testState.NumPatches()) } } @@ -78,8 +86,7 @@ func TestRetryHandlerAdd(t *testing.T) { } logger := hclog.NewNullLogger() - shutdownCh := make(chan struct{}) - c, err := client.New(logger, shutdownCh) + c, err := client.New(logger) if err != nil { t.Fatal(err) } @@ -89,6 +96,7 @@ func TestRetryHandlerAdd(t *testing.T) { namespace: "some-namespace", podName: "some-pod-name", patchesToRetry: make(map[string]*client.Patch), + client: c, } testPatch1 := &client.Patch{ @@ -113,34 +121,34 @@ func TestRetryHandlerAdd(t *testing.T) { } // Should be able to add all 4 patches. - r.Notify(c, testPatch1) + r.Notify(testPatch1) if len(r.patchesToRetry) != 1 { t.Fatal("expected 1 patch") } - r.Notify(c, testPatch2) + r.Notify(testPatch2) if len(r.patchesToRetry) != 2 { t.Fatal("expected 2 patches") } - r.Notify(c, testPatch3) + r.Notify(testPatch3) if len(r.patchesToRetry) != 3 { t.Fatal("expected 3 patches") } - r.Notify(c, testPatch4) + r.Notify(testPatch4) if len(r.patchesToRetry) != 4 { t.Fatal("expected 4 patches") } // Adding a dupe should result in no change. - r.Notify(c, testPatch4) + r.Notify(testPatch4) if len(r.patchesToRetry) != 4 { t.Fatal("expected 4 patches") } // Adding a reversion should result in its twin being subtracted. - r.Notify(c, &client.Patch{ + r.Notify(&client.Patch{ Operation: client.Add, Path: "four", Value: "false", }) if len(r.patchesToRetry) != 4 { t.Fatal("expected 4 patches") } - r.Notify(c, &client.Patch{ + r.Notify(&client.Patch{ Operation: client.Add, Path: "three", Value: "false", }) if len(r.patchesToRetry) != 4 { t.Fatal("expected 4 patches") } - r.Notify(c, &client.Patch{ + r.Notify(&client.Patch{ Operation: client.Add, Path: "two", Value: "false", }) if len(r.patchesToRetry) != 4 { t.Fatal("expected 4 patches") } - r.Notify(c, &client.Patch{ + r.Notify(&client.Patch{ Operation: client.Add, Path: "one", Value: "false", }) @@ -201,7 +209,7 @@ func TestRetryHandlerRacesAndDeadlocks(t *testing.T) { Value: "true", } - c, err := client.New(logger, shutdownCh) + c, err := client.New(logger) if err != nil { t.Fatal(err) } @@ -211,6 +219,8 @@ func TestRetryHandlerRacesAndDeadlocks(t *testing.T) { namespace: kubetest.ExpectedNamespace, podName: kubetest.ExpectedPodName, patchesToRetry: make(map[string]*client.Patch), + initialState: sr.State{}, + client: c, } // Now hit it as quickly as possible to see if we can produce @@ -221,20 +231,12 @@ func TestRetryHandlerRacesAndDeadlocks(t *testing.T) { for i := 0; i < numRoutines; i++ { go func() { <-start - r.Notify(c, testPatch) + r.Notify(testPatch) done <- true }() go func() { <-start - r.Run(shutdownCh, wait, c) - done <- true - }() - go func() { - <-start - r.SetInitialState(func() error { - c.GetPod(kubetest.ExpectedNamespace, kubetest.ExpectedPodName) - return nil - }) + r.Run(shutdownCh, wait) done <- true }() } @@ -242,7 +244,7 @@ // Allow up to 5 seconds for everything to finish.
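Editor's note: the race test above is a "starting gate" test: park a batch of goroutines on a shared channel, close it so they all contend at once (which is meaningful under -race), and bound the waits with a single timeout, as the hunk continuing below does. Distilled into a standalone shape (doRacyOperation is a hypothetical stand-in for the r.Notify/r.Run calls):

```go
func TestStartingGate(t *testing.T) {
	const numRoutines = 100
	start := make(chan struct{})
	done := make(chan struct{})

	for i := 0; i < numRoutines; i++ {
		go func() {
			<-start           // every goroutine parks at the gate
			doRacyOperation() // hypothetical racy call under test
			done <- struct{}{}
		}()
	}
	close(start) // release them all simultaneously

	// Bound the whole collection phase with one timer so a deadlocked
	// goroutine fails the test instead of hanging it.
	timeout := time.After(5 * time.Second)
	for i := 0; i < numRoutines; i++ {
		select {
		case <-done:
		case <-timeout:
			t.Fatal("test took too long to complete, check for deadlock")
		}
	}
}
```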
timer := time.NewTimer(5 * time.Second) - for i := 0; i < numRoutines*3; i++ { + for i := 0; i < numRoutines*2; i++ { select { case <-timer.C: t.Fatal("test took too long to complete, check for deadlock") @@ -284,11 +286,11 @@ func TestRetryHandlerAPIConnectivityProblemsInitialState(t *testing.T) { IsSealed: true, IsActive: true, IsPerformanceStandby: true, - }, "") + }) if err != nil { t.Fatal(err) } - if err := reg.Run(shutdownCh, wait); err != nil { + if err := reg.Run(shutdownCh, wait, ""); err != nil { t.Fatal(err) } @@ -375,13 +377,10 @@ func TestRetryHandlerAPIConnectivityProblemsNotifications(t *testing.T) { IsSealed: false, IsActive: false, IsPerformanceStandby: false, - }, "") + }) if err != nil { t.Fatal(err) } - if err := reg.Run(shutdownCh, wait); err != nil { - t.Fatal(err) - } if err := reg.NotifyActiveStateChange(true); err != nil { t.Fatal(err) @@ -396,6 +395,10 @@ func TestRetryHandlerAPIConnectivityProblemsNotifications(t *testing.T) { t.Fatal(err) } + if err := reg.Run(shutdownCh, wait, ""); err != nil { + t.Fatal(err) + } + // At this point, since the initial state can't be set, // remotely we should have false for all these labels. patch := testState.Get(pathToLabels + labelVaultVersion) diff --git a/serviceregistration/kubernetes/service_registration.go b/serviceregistration/kubernetes/service_registration.go index 54881c91b0f9..f1c9a3c8ce40 100644 --- a/serviceregistration/kubernetes/service_registration.go +++ b/serviceregistration/kubernetes/service_registration.go @@ -23,7 +23,7 @@ const ( pathToLabels = "/metadata/labels/" ) -func NewServiceRegistration(config map[string]string, logger hclog.Logger, state sr.State, _ string) (sr.ServiceRegistration, error) { +func NewServiceRegistration(config map[string]string, logger hclog.Logger, state sr.State) (sr.ServiceRegistration, error) { namespace, err := getRequiredField(logger, config, client.EnvVarKubernetesNamespace, "namespace") if err != nil { return nil, err @@ -32,19 +32,26 @@ func NewServiceRegistration(config map[string]string, logger hclog.Logger, state if err != nil { return nil, err } + + c, err := client.New(logger) + if err != nil { + return nil, err + } + // The Vault version must be sanitized because it can contain special // characters like "+" which aren't acceptable by the Kube API. state.VaultVersion = client.Sanitize(state.VaultVersion) return &serviceRegistration{ - logger: logger, - namespace: namespace, - podName: podName, - initialState: state, + logger: logger, + namespace: namespace, + podName: podName, retryHandler: &retryHandler{ logger: logger, namespace: namespace, podName: podName, + initialState: state, patchesToRetry: make(map[string]*client.Patch), + client: c, }, }, nil } @@ -52,91 +59,16 @@ func NewServiceRegistration(config map[string]string, logger hclog.Logger, state type serviceRegistration struct { logger hclog.Logger namespace, podName string - client *client.Client - initialState sr.State retryHandler *retryHandler } -func (r *serviceRegistration) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup) error { - c, err := client.New(r.logger, shutdownCh) - if err != nil { - return err - } - r.client = c - - // Now that we've populated the client, we can begin using the retry handler. - r.retryHandler.SetInitialState(r.setInitialState) - r.retryHandler.Run(shutdownCh, wait, c) - return nil -} - -func (r *serviceRegistration) setInitialState() error { - // Verify that the pod exists and our configuration looks good. 
- pod, err := r.client.GetPod(r.namespace, r.podName) - if err != nil { - return err - } - - // Now to initially label our pod. - if pod.Metadata == nil { - // This should never happen IRL, just being defensive. - return fmt.Errorf("no pod metadata on %+v", pod) - } - if pod.Metadata.Labels == nil { - // Notify the labels field, and the labels as part of that one call. - // The reason we must take a different approach to adding them is discussed here: - // https://stackoverflow.com/questions/57480205/error-while-applying-json-patch-to-kubernetes-custom-resource - if err := r.client.PatchPod(r.namespace, r.podName, &client.Patch{ - Operation: client.Add, - Path: "/metadata/labels", - Value: map[string]string{ - labelVaultVersion: r.initialState.VaultVersion, - labelActive: strconv.FormatBool(r.initialState.IsActive), - labelSealed: strconv.FormatBool(r.initialState.IsSealed), - labelPerfStandby: strconv.FormatBool(r.initialState.IsPerformanceStandby), - labelInitialized: strconv.FormatBool(r.initialState.IsInitialized), - }, - }); err != nil { - return err - } - } else { - // Create the labels through a patch to each individual field. - patches := []*client.Patch{ - { - Operation: client.Replace, - Path: pathToLabels + labelVaultVersion, - Value: r.initialState.VaultVersion, - }, - { - Operation: client.Replace, - Path: pathToLabels + labelActive, - Value: strconv.FormatBool(r.initialState.IsActive), - }, - { - Operation: client.Replace, - Path: pathToLabels + labelSealed, - Value: strconv.FormatBool(r.initialState.IsSealed), - }, - { - Operation: client.Replace, - Path: pathToLabels + labelPerfStandby, - Value: strconv.FormatBool(r.initialState.IsPerformanceStandby), - }, - { - Operation: client.Replace, - Path: pathToLabels + labelInitialized, - Value: strconv.FormatBool(r.initialState.IsInitialized), - }, - } - if err := r.client.PatchPod(r.namespace, r.podName, patches...); err != nil { - return err - } - } +func (r *serviceRegistration) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup, _ string) error { + r.retryHandler.Run(shutdownCh, wait) return nil } func (r *serviceRegistration) NotifyActiveStateChange(isActive bool) error { - r.retryHandler.Notify(r.client, &client.Patch{ + r.retryHandler.Notify(&client.Patch{ Operation: client.Replace, Path: pathToLabels + labelActive, Value: strconv.FormatBool(isActive), @@ -145,7 +77,7 @@ func (r *serviceRegistration) NotifyActiveStateChange(isActive bool) error { } func (r *serviceRegistration) NotifySealedStateChange(isSealed bool) error { - r.retryHandler.Notify(r.client, &client.Patch{ + r.retryHandler.Notify(&client.Patch{ Operation: client.Replace, Path: pathToLabels + labelSealed, Value: strconv.FormatBool(isSealed), @@ -154,7 +86,7 @@ func (r *serviceRegistration) NotifySealedStateChange(isSealed bool) error { } func (r *serviceRegistration) NotifyPerformanceStandbyStateChange(isStandby bool) error { - r.retryHandler.Notify(r.client, &client.Patch{ + r.retryHandler.Notify(&client.Patch{ Operation: client.Replace, Path: pathToLabels + labelPerfStandby, Value: strconv.FormatBool(isStandby), @@ -163,7 +95,7 @@ func (r *serviceRegistration) NotifyPerformanceStandbyStateChange(isStandby bool } func (r *serviceRegistration) NotifyInitializedStateChange(isInitialized bool) error { - r.retryHandler.Notify(r.client, &client.Patch{ + r.retryHandler.Notify(&client.Patch{ Operation: client.Replace, Path: pathToLabels + labelInitialized, Value: strconv.FormatBool(isInitialized), diff --git 
a/serviceregistration/kubernetes/service_registration_test.go b/serviceregistration/kubernetes/service_registration_test.go index 050ab563b783..a1bf001f1642 100644 --- a/serviceregistration/kubernetes/service_registration_test.go +++ b/serviceregistration/kubernetes/service_registration_test.go @@ -44,11 +44,11 @@ func TestServiceRegistration(t *testing.T) { IsActive: true, IsPerformanceStandby: true, } - reg, err := NewServiceRegistration(config, logger, state, "") + reg, err := NewServiceRegistration(config, logger, state) if err != nil { t.Fatal(err) } - if err := reg.Run(shutdownCh, &sync.WaitGroup{}); err != nil { + if err := reg.Run(shutdownCh, &sync.WaitGroup{}, ""); err != nil { t.Fatal(err) } diff --git a/serviceregistration/service_registration.go b/serviceregistration/service_registration.go index 3a9f08b01126..d463a6b79941 100644 --- a/serviceregistration/service_registration.go +++ b/serviceregistration/service_registration.go @@ -26,7 +26,7 @@ type State struct { // The config is the key/value pairs set _inside_ the service registration config stanza. // The state is the initial state. // The redirectAddr is Vault core's RedirectAddr. -type Factory func(config map[string]string, logger log.Logger, state State, redirectAddr string) (ServiceRegistration, error) +type Factory func(config map[string]string, logger log.Logger, state State) (ServiceRegistration, error) // ServiceRegistration is an interface that advertises the state of Vault to a // service discovery network. @@ -60,7 +60,7 @@ type ServiceRegistration interface { // }() // return nil // } - Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup) error + Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup, redirectAddr string) error // NotifyActiveStateChange is used by Core to notify that this Vault // instance has changed its status on whether it's active or is diff --git a/vault/core.go b/vault/core.go index f70ec59fe413..06720d73561d 100644 --- a/vault/core.go +++ b/vault/core.go @@ -648,7 +648,7 @@ func (c *CoreConfig) Clone() *CoreConfig { // not exist. func (c *CoreConfig) GetServiceRegistration() sr.ServiceRegistration { - // Check whether there is a ServiceRegistration explictly configured + // Check whether there is a ServiceRegistration explicitly configured if c.ServiceRegistration != nil { return c.ServiceRegistration } From 4eeb49cc51d94d1db48197f59b86722a0aef6b15 Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Mon, 11 May 2020 18:40:59 -0700 Subject: [PATCH 2/6] move state check to the internal func --- .../kubernetes/retry_handler.go | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/serviceregistration/kubernetes/retry_handler.go b/serviceregistration/kubernetes/retry_handler.go index 21d71ac8faf7..5f7e19f90e14 100644 --- a/serviceregistration/kubernetes/retry_handler.go +++ b/serviceregistration/kubernetes/retry_handler.go @@ -73,7 +73,7 @@ func (r *retryHandler) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup) { }() } -func (r *retryHandler) setInitialState(shutdownCh <-chan struct{}) error { +func (r *retryHandler) setInitialState(shutdownCh <-chan struct{}) { r.lock.Lock() defer r.lock.Unlock() @@ -126,6 +126,11 @@ func (r *retryHandler) Notify(patch *client.Patch) { // setInitialState sets the initial state remotely. This should be called with // the lock held. 
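Editor's note: patch 2, in the change continuing below, moves the already-set check inside setInitialStateInternal itself, turning it into an idempotent, lock-guarded bootstrap that every retry tick can call unconditionally. The resulting shape, with pushInitialLabels as a hypothetical stand-in for the GetPod/PatchPod work:

```go
// Callers hold r.lock. Safe to invoke on every tick: the remote work
// happens at most once.
func (r *retryHandler) setInitialStateInternal() error {
	if r.initialStateSet {
		return nil // already done on an earlier attempt; no-op
	}
	if err := r.pushInitialLabels(); err != nil {
		return err // flag stays false, so the next tick retries
	}
	r.initialStateSet = true
	return nil
}
```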
func (r *retryHandler) setInitialStateInternal() error { + // If this is set, we return immediately + if r.initialStateSet { + return nil + } + // Verify that the pod exists and our configuration looks good. pod, err := r.client.GetPod(r.namespace, r.podName) if err != nil { @@ -187,6 +192,7 @@ func (r *retryHandler) setInitialStateInternal() error { return err } } + r.initialStateSet = true return nil } @@ -213,17 +219,14 @@ func (r *retryHandler) updateState() { // Initial state must be set first, or subsequent notifications we've // received could get smashed by a late-arriving initial state. - if !r.initialStateSet { - if err := r.setInitialStateInternal(); err != nil { - if r.logger.IsWarn() { - r.logger.Warn(fmt.Sprintf("unable to set initial state due to %s, will retry", err.Error())) - } - // On failure, we leave the initial state func populated for - // the next retry. - return + // If the state is already set, this is a no-op. + if err := r.setInitialStateInternal(); err != nil { + if r.logger.IsWarn() { + r.logger.Warn(fmt.Sprintf("unable to set initial state due to %s, will retry", err.Error())) } - // On success, we set it to nil and allow the logic to continue. - r.initialStateSet = true + // On failure, we leave the initial state func populated for + // the next retry. + return } if len(r.patchesToRetry) == 0 { From fd92eab46bac2fd202474add4178d646f63117e4 Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Mon, 11 May 2020 18:46:56 -0700 Subject: [PATCH 3/6] sr/kubernetes: update setInitialStateInternal godoc --- serviceregistration/kubernetes/retry_handler.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/serviceregistration/kubernetes/retry_handler.go b/serviceregistration/kubernetes/retry_handler.go index 5f7e19f90e14..3ed21510e770 100644 --- a/serviceregistration/kubernetes/retry_handler.go +++ b/serviceregistration/kubernetes/retry_handler.go @@ -123,8 +123,8 @@ func (r *retryHandler) Notify(patch *client.Patch) { } } -// setInitialState sets the initial state remotely. This should be called with -// the lock held. +// setInitialStateInternal sets the initial state remotely. This should be +// called with the lock held. 
func (r *retryHandler) setInitialStateInternal() error { // If this is set, we return immediately if r.initialStateSet { From 50523ae3cfe20c488a86d55cdcc2ceca3441df57 Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Tue, 12 May 2020 09:33:12 -0700 Subject: [PATCH 4/6] sr/kubernetes: remove return in setInitialState --- serviceregistration/kubernetes/retry_handler.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/serviceregistration/kubernetes/retry_handler.go b/serviceregistration/kubernetes/retry_handler.go index 3ed21510e770..c01453dc6a99 100644 --- a/serviceregistration/kubernetes/retry_handler.go +++ b/serviceregistration/kubernetes/retry_handler.go @@ -93,8 +93,6 @@ func (r *retryHandler) setInitialState(shutdownCh <-chan struct{}) { case <-doneCh: case <-shutdownCh: } - - return nil } // Notify adds a patch to be retried until it's either completed without From 716cb9da60fbdbbbee299d5f5c827c7a1bd3b764 Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Tue, 12 May 2020 10:37:57 -0700 Subject: [PATCH 5/6] core/test: fix mockServiceRegistration --- vault/core_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vault/core_test.go b/vault/core_test.go index f14f6f305753..e558b0b2c41f 100644 --- a/vault/core_test.go +++ b/vault/core_test.go @@ -2476,7 +2476,7 @@ type mockServiceRegistration struct { runDiscoveryCount int } -func (m *mockServiceRegistration) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup) error { +func (m *mockServiceRegistration) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup, redirectAddr string) error { m.runDiscoveryCount++ return nil } From 780ab17de786883ed5b44d645efe395542284d7b Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Thu, 14 May 2020 12:14:54 -0700 Subject: [PATCH 6/6] address review feedback --- .../kubernetes/retry_handler.go | 17 ++++++++++----- serviceregistration/service_registration.go | 21 +++++++++++-------- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/serviceregistration/kubernetes/retry_handler.go b/serviceregistration/kubernetes/retry_handler.go index c01453dc6a99..68afa8cdc576 100644 --- a/serviceregistration/kubernetes/retry_handler.go +++ b/serviceregistration/kubernetes/retry_handler.go @@ -43,11 +43,10 @@ type retryHandler struct { // Run must be called for retries to be started. func (r *retryHandler) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup) { - wait.Add(1) - r.setInitialState(shutdownCh) // Run this in a go func so this call doesn't block. + wait.Add(1) go func() { // Make sure Vault will give us time to finish up here. 
defer wait.Done() @@ -55,10 +54,16 @@ func (r *retryHandler) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup) { var g run.Group // This run group watches for the shutdownCh + shutdownActorStop := make(chan struct{}) g.Add(func() error { - <-shutdownCh + select { + case <-shutdownCh: + case <-shutdownActorStop: + } return nil - }, func(error) {}) + }, func(error) { + close(shutdownActorStop) + }) checkUpdateStateStop := make(chan struct{}) g.Add(func() error { @@ -69,7 +74,9 @@ func (r *retryHandler) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup) { r.periodicUpdateState(checkUpdateStateStop) return nil }, func(error) { close(checkUpdateStateStop) r.client.Shutdown() }) - g.Run() + if err := g.Run(); err != nil { + r.logger.Error("error encountered during periodic state update", "error", err) + } }() } diff --git a/serviceregistration/service_registration.go b/serviceregistration/service_registration.go index d463a6b79941..4093e906ead9 100644 --- a/serviceregistration/service_registration.go +++ b/serviceregistration/service_registration.go @@ -31,21 +31,24 @@ type Factory func(config map[string]string, logger log.Logger, state State) (Ser // ServiceRegistration is an interface that advertises the state of Vault to a // service discovery network. type ServiceRegistration interface { - // Run provides a shutdownCh and wait WaitGroup. The shutdownCh - // is for monitoring when a shutdown occurs and initiating any actions needed - // to leave service registration in a final state. When finished, signalling - // that with wait means that Vault will wait until complete. + // Run provides a shutdownCh, wait WaitGroup, and redirectAddr. The + // shutdownCh is for monitoring when a shutdown occurs and initiating any + // actions needed to leave service registration in a final state. When + // finished, signalling that with wait means that Vault will wait until + // complete. The redirectAddr is an optional parameter for implementations + // that might need to communicate with Vault's listener via this address. + // + // Run is called just after Factory instantiation so can be relied upon + // for controlling shutdown behavior. + // Here is an example of its intended use: + // func Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup, redirectAddr string) error { + // + // // Since we are going to want Vault to wait to shutdown + // // until after we do cleanup... + // wait.Add(1) + // + // // Run shutdown code in a goroutine so Run doesn't block. + // go func(){ + // // Ensure that when this ends, no matter how it ends, + // // we don't cause Vault to hang on shutdown. + // defer wait.Done()
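Editor's note: after this series, the Factory and Run signatures above form the entire contract a backend implements: construction without a redirect address, then Run with it. A hypothetical no-op backend that satisfies the new contract, shown only to make it concrete (not part of the patch):

```go
package noop

import (
	"sync"

	log "github.com/hashicorp/go-hclog"
	sr "github.com/hashicorp/vault/serviceregistration"
)

type noopRegistration struct{}

// NewServiceRegistration matches the new Factory signature: the redirect
// address is no longer available at construction time.
func NewServiceRegistration(_ map[string]string, _ log.Logger, _ sr.State) (sr.ServiceRegistration, error) {
	return &noopRegistration{}, nil
}

// Run receives redirectAddr, which is only known once the core is built.
func (n *noopRegistration) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup, redirectAddr string) error {
	wait.Add(1)
	go func() {
		defer wait.Done()
		<-shutdownCh // a real backend would deregister itself here
	}()
	return nil
}

func (n *noopRegistration) NotifyActiveStateChange(bool) error             { return nil }
func (n *noopRegistration) NotifySealedStateChange(bool) error             { return nil }
func (n *noopRegistration) NotifyPerformanceStandbyStateChange(bool) error { return nil }
func (n *noopRegistration) NotifyInitializedStateChange(bool) error        { return nil }
```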