diff --git a/cmd/root.go b/cmd/root.go index 09e933f0c..cc118ba20 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -178,6 +178,27 @@ Instance Level Configuration my-project:us-central1:my-db-server \ 'my-project:us-central1:my-other-server?address=0.0.0.0&port=7000' +Health checks + + When enabling the --health-check flag, the proxy will start an HTTP server + on localhost with three endpoints: + + - /startup: Returns 200 status when the proxy has finished starting up. + Otherwise returns 503 status. + + - /readiness: Returns 200 status when the proxy has started, has available + connections if max connections have been set with the --max-connections + flag, and when the proxy can connect to all registered instances. Otherwise, + returns a 503 status. Optionally supports a min-ready query param (e.g., + /readiness?min-ready=3) where the proxy will return a 200 status if the + proxy can connect successfully to at least min-ready number of instances. If + min-ready exceeds the number of registered instances, returns a 400. + + - /liveness: Always returns 200 status. If this endpoint is not responding, + the proxy is in a bad state and should be restarted. + + To configure the address, use --http-address. + Service Account Impersonation The proxy supports service account impersonation with the @@ -706,6 +727,8 @@ func runSignalWrapper(cmd *Command) error { notify := func() {} if cmd.healthCheck { needsHTTPServer = true + cmd.logger.Infof("Starting health check server at %s", + net.JoinHostPort(cmd.httpAddress, cmd.httpPort)) hc := healthcheck.NewCheck(p, cmd.logger) mux.HandleFunc("/startup", hc.HandleStartup) mux.HandleFunc("/readiness", hc.HandleReadiness) @@ -716,7 +739,7 @@ func runSignalWrapper(cmd *Command) error { // Start the HTTP server if anything requiring HTTP is specified. 
if needsHTTPServer { server := &http.Server{ - Addr: fmt.Sprintf("%s:%s", cmd.httpAddress, cmd.httpPort), + Addr: net.JoinHostPort(cmd.httpAddress, cmd.httpPort), Handler: mux, } // Start the HTTP server. diff --git a/internal/healthcheck/healthcheck.go b/internal/healthcheck/healthcheck.go index c2c083560..1491d150c 100644 --- a/internal/healthcheck/healthcheck.go +++ b/internal/healthcheck/healthcheck.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "net/http" + "strconv" "sync" "github.com/GoogleCloudPlatform/cloud-sql-proxy/v2/cloudsql" @@ -67,7 +68,7 @@ var errNotStarted = errors.New("proxy is not started") // HandleReadiness ensures the Check has been notified of successful startup, // that the proxy has not reached maximum connections, and that all connections // are healthy. -func (c *Check) HandleReadiness(w http.ResponseWriter, _ *http.Request) { +func (c *Check) HandleReadiness(w http.ResponseWriter, req *http.Request) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -88,14 +89,67 @@ func (c *Check) HandleReadiness(w http.ResponseWriter, _ *http.Request) { return } - err := c.proxy.CheckConnections(ctx) - if err != nil { + var minReady *int + q := req.URL.Query() + if v := q.Get("min-ready"); v != "" { + n, err := strconv.Atoi(v) + if err != nil { + c.logger.Errorf("[Health Check] min-ready must be a valid integer, got = %q", v) + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, "min-ready must be a valid integer, got = %q", v) + return + } + if n <= 0 { + c.logger.Errorf("[Health Check] min-ready %q must be greater than zero", v) + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, "min-ready must be greater than zero, got = %q", v) + return + } + minReady = &n + } + + total, err := c.proxy.CheckConnections(ctx) + + switch { + case minReady != nil && *minReady > total: + // When min ready is set and exceeds total instances, 400 status. 
+ mErr := fmt.Errorf( + "min-ready (%v) must be less than or equal to the number of registered instances (%v)", + *minReady, total, + ) + c.logger.Errorf("[Health Check] Readiness failed: %v", mErr) + + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(mErr.Error())) + return + case err != nil && minReady != nil: + // When there's an error and min ready is set, AND min ready instances + // are not ready, 503 status. + c.logger.Errorf("[Health Check] Readiness failed: %v", err) + + mErr, ok := err.(proxy.MultiErr) + if !ok { + // If the err is not a MultiErr, just return it as is. + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte(err.Error())) + return + } + + areReady := total - len(mErr) + if areReady < *minReady { + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte(err.Error())) + return + } + case err != nil: + // When there's just an error without min-ready: 503 status. c.logger.Errorf("[Health Check] Readiness failed: %v", err) w.WriteHeader(http.StatusServiceUnavailable) w.Write([]byte(err.Error())) return } + // No error cases apply, 200 status. w.WriteHeader(http.StatusOK) w.Write([]byte("ok")) } diff --git a/internal/healthcheck/healthcheck_test.go b/internal/healthcheck/healthcheck_test.go index 00c03495d..e0ebdd50e 100644 --- a/internal/healthcheck/healthcheck_test.go +++ b/internal/healthcheck/healthcheck_test.go @@ -22,8 +22,10 @@ import ( "net" "net/http" "net/http/httptest" + "net/url" "os" "strings" + "sync/atomic" "testing" "time" @@ -71,6 +73,21 @@ func (*fakeDialer) Close() error { return nil } +type flakeyDialer struct { + dialCount uint64 + fakeDialer +} + +// Dial fails on odd calls and succeeds on even calls. 
+func (f *flakeyDialer) Dial(_ context.Context, _ string, _ ...cloudsqlconn.DialOption) (net.Conn, error) { + c := atomic.AddUint64(&f.dialCount, 1) + if c%2 == 0 { + conn, _ := net.Pipe() + return conn, nil + } + return nil, errors.New("flakey dialer fails on odd calls") +} + type errorDialer struct { fakeDialer } @@ -79,13 +96,11 @@ func (*errorDialer) Dial(_ context.Context, _ string, _ ...cloudsqlconn.DialOpti return nil, errors.New("errorDialer always errors") } -func newProxyWithParams(t *testing.T, maxConns uint64, dialer cloudsql.Dialer) *proxy.Client { +func newProxyWithParams(t *testing.T, maxConns uint64, dialer cloudsql.Dialer, instances []proxy.InstanceConnConfig) *proxy.Client { c := &proxy.Config{ - Addr: proxyHost, - Port: proxyPort, - Instances: []proxy.InstanceConnConfig{ - {Name: "proj:region:pg"}, - }, + Addr: proxyHost, + Port: proxyPort, + Instances: instances, MaxConnections: maxConns, } p, err := proxy.NewClient(context.Background(), dialer, logger, c) @@ -96,15 +111,17 @@ func newProxyWithParams(t *testing.T, maxConns uint64, dialer cloudsql.Dialer) * } func newTestProxyWithMaxConns(t *testing.T, maxConns uint64) *proxy.Client { - return newProxyWithParams(t, maxConns, &fakeDialer{}) + return newProxyWithParams(t, maxConns, &fakeDialer{}, []proxy.InstanceConnConfig{ + {Name: "proj:region:pg"}, + }) } func newTestProxyWithDialer(t *testing.T, d cloudsql.Dialer) *proxy.Client { - return newProxyWithParams(t, 0, d) + return newProxyWithParams(t, 0, d, []proxy.InstanceConnConfig{{Name: "proj:region:pg"}}) } func newTestProxy(t *testing.T) *proxy.Client { - return newProxyWithParams(t, 0, &fakeDialer{}) + return newProxyWithParams(t, 0, &fakeDialer{}, []proxy.InstanceConnConfig{{Name: "proj:region:pg"}}) } func TestHandleStartupWhenNotNotified(t *testing.T) { @@ -117,7 +134,7 @@ func TestHandleStartupWhenNotNotified(t *testing.T) { check := healthcheck.NewCheck(p, logger) rec := httptest.NewRecorder() - check.HandleStartup(rec, &http.Request{}) 
+ check.HandleStartup(rec, &http.Request{URL: &url.URL{}}) // Startup is not complete because the Check has not been notified of the // proxy's startup. @@ -139,7 +156,7 @@ func TestHandleStartupWhenNotified(t *testing.T) { check.NotifyStarted() rec := httptest.NewRecorder() - check.HandleStartup(rec, &http.Request{}) + check.HandleStartup(rec, &http.Request{URL: &url.URL{}}) resp := rec.Result() if got, want := resp.StatusCode, http.StatusOK; got != want { @@ -157,7 +174,7 @@ func TestHandleReadinessWhenNotNotified(t *testing.T) { check := healthcheck.NewCheck(p, logger) rec := httptest.NewRecorder() - check.HandleReadiness(rec, &http.Request{}) + check.HandleReadiness(rec, &http.Request{URL: &url.URL{}}) resp := rec.Result() if got, want := resp.StatusCode, http.StatusServiceUnavailable; got != want { @@ -193,13 +210,14 @@ func TestHandleReadinessForMaxConns(t *testing.T) { waitForConnect := func(t *testing.T, wantCode int) *http.Response { for i := 0; i < 10; i++ { rec := httptest.NewRecorder() - check.HandleReadiness(rec, &http.Request{}) + check.HandleReadiness(rec, &http.Request{URL: &url.URL{}}) resp := rec.Result() if resp.StatusCode == wantCode { return resp } time.Sleep(time.Second) } + t.Fatalf("failed to receive status code = %v", wantCode) return nil } resp := waitForConnect(t, http.StatusServiceUnavailable) @@ -224,7 +242,7 @@ func TestHandleReadinessWithConnectionProblems(t *testing.T) { check.NotifyStarted() rec := httptest.NewRecorder() - check.HandleReadiness(rec, &http.Request{}) + check.HandleReadiness(rec, &http.Request{URL: &url.URL{}}) resp := rec.Result() if got, want := resp.StatusCode, http.StatusServiceUnavailable; got != want { @@ -239,3 +257,85 @@ func TestHandleReadinessWithConnectionProblems(t *testing.T) { t.Fatalf("want substring with = %q, got = %v", want, string(body)) } } + +func TestReadinessWithMinReady(t *testing.T) { + tcs := []struct { + desc string + minReady string + wantStatus int + dialer cloudsql.Dialer + }{ + { + desc: 
"when min ready is zero", + minReady: "0", + wantStatus: http.StatusBadRequest, + dialer: &fakeDialer{}, + }, + { + desc: "when min ready is less than zero", + minReady: "-1", + wantStatus: http.StatusBadRequest, + dialer: &fakeDialer{}, + }, + { + desc: "when only one instance must be ready", + minReady: "1", + wantStatus: http.StatusOK, + dialer: &flakeyDialer{}, // fails on first call, succeeds on second + }, + { + desc: "when all instances must be ready", + minReady: "2", + wantStatus: http.StatusServiceUnavailable, + dialer: &errorDialer{}, + }, + { + desc: "when min ready is greater than the number of instances", + minReady: "3", + wantStatus: http.StatusBadRequest, + dialer: &fakeDialer{}, + }, + { + desc: "when min ready is bogus", + minReady: "bogus", + wantStatus: http.StatusBadRequest, + dialer: &fakeDialer{}, + }, + { + desc: "when min ready is not set", + minReady: "", + wantStatus: http.StatusOK, + dialer: &fakeDialer{}, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + p := newProxyWithParams(t, 0, + tc.dialer, + []proxy.InstanceConnConfig{ + {Name: "p:r:instance-1"}, + {Name: "p:r:instance-2"}, + }, + ) + defer func() { + if err := p.Close(); err != nil { + t.Logf("failed to close proxy client: %v", err) + } + }() + + check := healthcheck.NewCheck(p, logger) + check.NotifyStarted() + u, err := url.Parse(fmt.Sprintf("/readiness?min-ready=%s", tc.minReady)) + if err != nil { + t.Fatal(err) + } + rec := httptest.NewRecorder() + check.HandleReadiness(rec, &http.Request{URL: u}) + + resp := rec.Result() + if got, want := resp.StatusCode, tc.wantStatus; got != want { + t.Fatalf("want = %v, got = %v", want, got) + } + }) + } +} diff --git a/internal/proxy/fuse_test.go b/internal/proxy/fuse_test.go index 823570aba..2068e4618 100644 --- a/internal/proxy/fuse_test.go +++ b/internal/proxy/fuse_test.go @@ -287,9 +287,13 @@ func TestFUSECheckConnections(t *testing.T) { conn := tryDialUnix(t, filepath.Join(fuseDir, "proj:reg:mysql")) 
defer conn.Close() - if err := c.CheckConnections(context.Background()); err != nil { + n, err := c.CheckConnections(context.Background()) + if err != nil { t.Fatalf("c.CheckConnections(): %v", err) } + if want, got := 1, n; want != got { + t.Fatalf("CheckConnections number of connections: want = %v, got = %v", want, got) + } // verify the dialer was invoked twice, once for connect, once for check // connection diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 406b32c07..7ba1086ce 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -438,9 +438,9 @@ func NewClient(ctx context.Context, d cloudsql.Dialer, l cloudsql.Logger, conf * return c, nil } -// CheckConnections dials each registered instance and reports any errors that -// may have occurred. -func (c *Client) CheckConnections(ctx context.Context) error { +// CheckConnections dials each registered instance and reports the number of +// connections checked and any errors that may have occurred. +func (c *Client) CheckConnections(ctx context.Context) (int, error) { var ( wg sync.WaitGroup errCh = make(chan error, len(c.mnts)) @@ -460,14 +460,17 @@ func (c *Client) CheckConnections(ctx context.Context) error { } cErr := conn.Close() if cErr != nil { - errCh <- fmt.Errorf("%v: %v", m.inst, cErr) + c.logger.Errorf( + "connection check failed to close connection for %v: %v", + m.inst, cErr, + ) } }(mnt) } wg.Wait() var mErr MultiErr - for i := 0; i < len(c.mnts); i++ { + for i := 0; i < len(mnts); i++ { select { case err := <-errCh: mErr = append(mErr, err) @@ -475,10 +478,11 @@ func (c *Client) CheckConnections(ctx context.Context) error { continue } } + mLen := len(mnts) if len(mErr) > 0 { - return mErr + return mLen, mErr } - return nil + return mLen, nil } // ConnCount returns the number of open connections and the maximum allowed diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 1f04cb6b1..86ceca99e 100644 --- a/internal/proxy/proxy_test.go +++ 
b/internal/proxy/proxy_test.go @@ -586,9 +586,13 @@ func TestCheckConnections(t *testing.T) { defer c.Close() go c.Serve(context.Background(), func() {}) - if err = c.CheckConnections(context.Background()); err != nil { + n, err := c.CheckConnections(context.Background()) + if err != nil { t.Fatalf("CheckConnections failed: %v", err) } + if want, got := len(in.Instances), n; want != got { + t.Fatalf("CheckConnections number of connections: want = %v, got = %v", want, got) + } if want, got := 1, d.dialAttempts(); want != got { t.Fatalf("dial attempts: want = %v, got = %v", want, got) @@ -610,8 +614,11 @@ func TestCheckConnections(t *testing.T) { defer c.Close() go c.Serve(context.Background(), func() {}) - err = c.CheckConnections(context.Background()) + n, err = c.CheckConnections(context.Background()) if err == nil { t.Fatal("CheckConnections should have failed, but did not") } + if want, got := len(in.Instances), n; want != got { + t.Fatalf("CheckConnections number of connections: want = %v, got = %v", want, got) + } }