diff --git a/cmd/root.go b/cmd/root.go index 8f65f613d..b0c2a4d35 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1062,8 +1062,10 @@ func runSignalWrapper(cmd *Command) (err error) { var ( needsHTTPServer bool mux = http.NewServeMux() - notify = func() {} + notifyStarted = func() {} + notifyStopped = func() {} ) + if cmd.conf.Prometheus { needsHTTPServer = true e, err := prometheus.NewExporter(prometheus.Options{ @@ -1083,8 +1085,10 @@ func runSignalWrapper(cmd *Command) (err error) { mux.HandleFunc("/startup", hc.HandleStartup) mux.HandleFunc("/readiness", hc.HandleReadiness) mux.HandleFunc("/liveness", hc.HandleLiveness) - notify = hc.NotifyStarted + notifyStarted = hc.NotifyStarted + notifyStopped = hc.NotifyStopped } + defer notifyStopped() // Start the HTTP server if anything requiring HTTP is specified. if needsHTTPServer { go startHTTPServer( @@ -1127,7 +1131,7 @@ func runSignalWrapper(cmd *Command) (err error) { ) } - go func() { shutdownCh <- p.Serve(ctx, notify) }() + go func() { shutdownCh <- p.Serve(ctx, notifyStarted) }() err = <-shutdownCh switch { diff --git a/internal/healthcheck/healthcheck.go b/internal/healthcheck/healthcheck.go index 9192970b2..f0dedd6d7 100644 --- a/internal/healthcheck/healthcheck.go +++ b/internal/healthcheck/healthcheck.go @@ -16,11 +16,9 @@ package healthcheck import ( - "context" "errors" "fmt" "net/http" - "strconv" "sync" "github.com/GoogleCloudPlatform/cloud-sql-proxy/v2/cloudsql" @@ -30,25 +28,34 @@ import ( // Check provides HTTP handlers for use as healthchecks typically in a // Kubernetes context. type Check struct { - once *sync.Once - started chan struct{} - proxy *proxy.Client - logger cloudsql.Logger + startedOnce *sync.Once + started chan struct{} + stoppedOnce *sync.Once + stopped chan struct{} + proxy *proxy.Client + logger cloudsql.Logger } // NewCheck is the initializer for Check. 
func NewCheck(p *proxy.Client, l cloudsql.Logger) *Check { return &Check{ - once: &sync.Once{}, - started: make(chan struct{}), - proxy: p, - logger: l, + startedOnce: &sync.Once{}, + started: make(chan struct{}), + stoppedOnce: &sync.Once{}, + stopped: make(chan struct{}), + proxy: p, + logger: l, } } // NotifyStarted notifies the check that the proxy has started up successfully. func (c *Check) NotifyStarted() { - c.once.Do(func() { close(c.started) }) + c.startedOnce.Do(func() { close(c.started) }) +} + +// NotifyStopped notifies the check that the proxy has initiated its shutdown. +func (c *Check) NotifyStopped() { + c.stoppedOnce.Do(func() { close(c.stopped) }) } // HandleStartup reports whether the Check has been notified of startup. @@ -63,15 +70,15 @@ func (c *Check) HandleStartup(w http.ResponseWriter, _ *http.Request) { } } -var errNotStarted = errors.New("proxy is not started") +var ( + errNotStarted = errors.New("proxy is not started") + errStopped = errors.New("proxy has stopped") +) // HandleReadiness ensures the Check has been notified of successful startup, -// that the proxy has not reached maximum connections, and that all connections -// are healthy. -func (c *Check) HandleReadiness(w http.ResponseWriter, req *http.Request) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - +// that the proxy has not reached maximum connections, and that the Proxy has +// not started shutting down. 
+func (c *Check) HandleReadiness(w http.ResponseWriter, _ *http.Request) { select { case <-c.started: default: @@ -81,68 +88,17 @@ func (c *Check) HandleReadiness(w http.ResponseWriter, req *http.Request) { return } - if open, max := c.proxy.ConnCount(); max > 0 && open == max { - err := fmt.Errorf("max connections reached (open = %v, max = %v)", open, max) - c.logger.Errorf("[Health Check] Readiness failed: %v", err) + select { + case <-c.stopped: + c.logger.Errorf("[Health Check] Readiness failed: %v", errStopped) w.WriteHeader(http.StatusServiceUnavailable) - w.Write([]byte(err.Error())) + w.Write([]byte(errStopped.Error())) return + default: } - var minReady *int - q := req.URL.Query() - if v := q.Get("min-ready"); v != "" { - n, err := strconv.Atoi(v) - if err != nil { - c.logger.Errorf("[Health Check] min-ready must be a valid integer, got = %q", v) - w.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(w, "min-query must be a valid integer, got = %q", v) - return - } - if n <= 0 { - c.logger.Errorf("[Health Check] min-ready %q must be greater than zero", v) - w.WriteHeader(http.StatusBadRequest) - fmt.Fprint(w, "min-query must be greater than zero", v) - return - } - minReady = &n - } - - total, err := c.proxy.CheckConnections(ctx) - - switch { - case minReady != nil && *minReady > total: - // When min ready is set and exceeds total instances, 400 status. - mErr := fmt.Errorf( - "min-ready (%v) must be less than or equal to the number of registered instances (%v)", - *minReady, total, - ) - c.logger.Errorf("[Health Check] Readiness failed: %v", mErr) - - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte(mErr.Error())) - return - case err != nil && minReady != nil: - // When there's an error and min ready is set, AND min ready instances - // are not ready, 503 status. - c.logger.Errorf("[Health Check] Readiness failed: %v", err) - - mErr, ok := err.(proxy.MultiErr) - if !ok { - // If the err is not a MultiErr, just return it as is. 
- w.WriteHeader(http.StatusServiceUnavailable) - w.Write([]byte(err.Error())) - return - } - - areReady := total - len(mErr) - if areReady < *minReady { - w.WriteHeader(http.StatusServiceUnavailable) - w.Write([]byte(err.Error())) - return - } - case err != nil: - // When there's just an error without min-ready: 503 status. + if open, max := c.proxy.ConnCount(); max > 0 && open == max { + err := fmt.Errorf("max connections reached (open = %v, max = %v)", open, max) c.logger.Errorf("[Health Check] Readiness failed: %v", err) w.WriteHeader(http.StatusServiceUnavailable) w.Write([]byte(err.Error())) diff --git a/internal/healthcheck/healthcheck_test.go b/internal/healthcheck/healthcheck_test.go index e0ebdd50e..b4729bb3f 100644 --- a/internal/healthcheck/healthcheck_test.go +++ b/internal/healthcheck/healthcheck_test.go @@ -16,7 +16,6 @@ package healthcheck_test import ( "context" - "errors" "fmt" "io" "net" @@ -25,7 +24,6 @@ import ( "net/url" "os" "strings" - "sync/atomic" "testing" "time" @@ -73,29 +71,6 @@ func (*fakeDialer) Close() error { return nil } -type flakeyDialer struct { - dialCount uint64 - fakeDialer -} - -// Dial fails on odd calls and succeeds on even calls. 
-func (f *flakeyDialer) Dial(_ context.Context, _ string, _ ...cloudsqlconn.DialOption) (net.Conn, error) { - c := atomic.AddUint64(&f.dialCount, 1) - if c%2 == 0 { - conn, _ := net.Pipe() - return conn, nil - } - return nil, errors.New("flakey dialer fails on odd calls") -} - -type errorDialer struct { - fakeDialer -} - -func (*errorDialer) Dial(_ context.Context, _ string, _ ...cloudsqlconn.DialOption) (net.Conn, error) { - return nil, errors.New("errorDialer always errors") -} - func newProxyWithParams(t *testing.T, maxConns uint64, dialer cloudsql.Dialer, instances []proxy.InstanceConnConfig) *proxy.Client { c := &proxy.Config{ Addr: proxyHost, @@ -116,10 +91,6 @@ func newTestProxyWithMaxConns(t *testing.T, maxConns uint64) *proxy.Client { }) } -func newTestProxyWithDialer(t *testing.T, d cloudsql.Dialer) *proxy.Client { - return newProxyWithParams(t, 0, d, []proxy.InstanceConnConfig{{Name: "proj:region:pg"}}) -} - func newTestProxy(t *testing.T) *proxy.Client { return newProxyWithParams(t, 0, &fakeDialer{}, []proxy.InstanceConnConfig{{Name: "proj:region:pg"}}) } @@ -182,6 +153,27 @@ func TestHandleReadinessWhenNotNotified(t *testing.T) { } } +func TestHandleReadinessWhenStopped(t *testing.T) { + p := newTestProxy(t) + defer func() { + if err := p.Close(); err != nil { + t.Logf("failed to close proxy client: %v", err) + } + }() + check := healthcheck.NewCheck(p, logger) + + check.NotifyStarted() // The Proxy has started. + check.NotifyStopped() // And now the Proxy is shutting down. 
+ + rec := httptest.NewRecorder() + check.HandleReadiness(rec, &http.Request{URL: &url.URL{}}) + + resp := rec.Result() + if got, want := resp.StatusCode, http.StatusServiceUnavailable; got != want { + t.Fatalf("want = %v, got = %v", want, got) + } +} + func TestHandleReadinessForMaxConns(t *testing.T) { p := newTestProxyWithMaxConns(t, 1) defer func() { @@ -230,112 +222,3 @@ func TestHandleReadinessForMaxConns(t *testing.T) { t.Fatalf("want max connections error, got = %v", string(body)) } } - -func TestHandleReadinessWithConnectionProblems(t *testing.T) { - p := newTestProxyWithDialer(t, &errorDialer{}) // error dialer will error on dial - defer func() { - if err := p.Close(); err != nil { - t.Logf("failed to close proxy client: %v", err) - } - }() - check := healthcheck.NewCheck(p, logger) - check.NotifyStarted() - - rec := httptest.NewRecorder() - check.HandleReadiness(rec, &http.Request{URL: &url.URL{}}) - - resp := rec.Result() - if got, want := resp.StatusCode, http.StatusServiceUnavailable; got != want { - t.Fatalf("want = %v, got = %v", want, got) - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("failed to read response body: %v", err) - } - if want := "errorDialer"; !strings.Contains(string(body), want) { - t.Fatalf("want substring with = %q, got = %v", want, string(body)) - } -} - -func TestReadinessWithMinReady(t *testing.T) { - tcs := []struct { - desc string - minReady string - wantStatus int - dialer cloudsql.Dialer - }{ - { - desc: "when min ready is zero", - minReady: "0", - wantStatus: http.StatusBadRequest, - dialer: &fakeDialer{}, - }, - { - desc: "when min ready is less than zero", - minReady: "-1", - wantStatus: http.StatusBadRequest, - dialer: &fakeDialer{}, - }, - { - desc: "when only one instance must be ready", - minReady: "1", - wantStatus: http.StatusOK, - dialer: &flakeyDialer{}, // fails on first call, succeeds on second - }, - { - desc: "when all instances must be ready", - minReady: "2", - wantStatus: 
http.StatusServiceUnavailable, - dialer: &errorDialer{}, - }, - { - desc: "when min ready is greater than the number of instances", - minReady: "3", - wantStatus: http.StatusBadRequest, - dialer: &fakeDialer{}, - }, - { - desc: "when min ready is bogus", - minReady: "bogus", - wantStatus: http.StatusBadRequest, - dialer: &fakeDialer{}, - }, - { - desc: "when min ready is not set", - minReady: "", - wantStatus: http.StatusOK, - dialer: &fakeDialer{}, - }, - } - for _, tc := range tcs { - t.Run(tc.desc, func(t *testing.T) { - p := newProxyWithParams(t, 0, - tc.dialer, - []proxy.InstanceConnConfig{ - {Name: "p:r:instance-1"}, - {Name: "p:r:instance-2"}, - }, - ) - defer func() { - if err := p.Close(); err != nil { - t.Logf("failed to close proxy client: %v", err) - } - }() - - check := healthcheck.NewCheck(p, logger) - check.NotifyStarted() - u, err := url.Parse(fmt.Sprintf("/readiness?min-ready=%s", tc.minReady)) - if err != nil { - t.Fatal(err) - } - rec := httptest.NewRecorder() - check.HandleReadiness(rec, &http.Request{URL: u}) - - resp := rec.Result() - if got, want := resp.StatusCode, tc.wantStatus; got != want { - t.Fatalf("want = %v, got = %v", want, got) - } - }) - } -}