diff --git a/README.md b/README.md index 229224f3..fad143a1 100644 --- a/README.md +++ b/README.md @@ -279,18 +279,22 @@ for possible values. ## Localhost Admin Server -The Proxy includes support for an admin server on localhost. By default, the -admin server is not enabled. To enable the server, pass the `--debug` flag. -This will start the server on localhost at port 9091. To change the port, use -the `--admin-port` flag. +The Proxy includes support for an admin server on localhost. By default, +the admin server is not enabled. To enable the server, pass the --debug or +--quitquitquit flag. This will start the server on localhost at port 9091. +To change the port, use the --admin-port flag. -The admin server includes Go's pprof tool and is available at `/debug/pprof/`. +When --debug is set, the admin server enables Go's profiler available at +/debug/pprof/. See the [documentation on pprof][pprof] for details on how to use the profiler. -[pprof]: https://pkg.go.dev/net/http/pprof. +When --quitquitquit is set, the admin server adds an endpoint at +/quitquitquit. The admin server exits gracefully when it receives a POST +request at /quitquitquit. +[pprof]: https://pkg.go.dev/net/http/pprof. ## Support policy diff --git a/cmd/errors.go b/cmd/errors.go index 5a2bbf9c..8739eb53 100644 --- a/cmd/errors.go +++ b/cmd/errors.go @@ -28,6 +28,11 @@ var ( Err: errors.New("SIGTERM signal received"), Code: 143, } + + errQuitQuitQuit = &exitError{ + Err: errors.New("/quitquitquit received request"), + Code: 0, // This error guarantees a clean exit. + } ) func newBadCommandError(msg string) error { diff --git a/cmd/root.go b/cmd/root.go index 0f4806c5..22212eaa 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -29,6 +29,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "syscall" "time" @@ -177,9 +178,9 @@ Instance Level Configuration When necessary, you may specify the full path to a Unix socket. Set the unix-socket-path query parameter to the absolute path of the Unix socket for - the database instance. The parent directory of the unix-socket-path must + the database instance. The parent directory of the unix-socket-path must exist when the proxy starts or else socket creation will fail. For Postgres - instances, the proxy will ensure that the last path element is + instances, the proxy will ensure that the last path element is '.s.PGSQL.5432' appending it if necessary. For example, ./cloud-sql-proxy \ @@ -268,15 +269,20 @@ Configuration using environment variables Localhost Admin Server The Proxy includes support for an admin server on localhost. By default, - the admin server is not enabled. To enable the server, pass the --debug - flag. This will start the server on localhost at port 9091. To change the - port, use the --admin-port flag. + the admin server is not enabled. To enable the server, pass the --debug or + --quitquitquit flag. This will start the server on localhost at port 9091. + To change the port, use the --admin-port flag. - The admin server includes Go's pprof tool and is available at + When --debug is set, the admin server enables Go's profiler available at /debug/pprof/. See the documentation on pprof for details on how to use the profiler at https://pkg.go.dev/net/http/pprof. + + When --quitquitquit is set, the admin server adds an endpoint at + /quitquitquit. The admin server exits gracefully when it receives a POST + request at /quitquitquit. + ` const envPrefix = "ALLOYDB_PROXY" @@ -411,7 +417,9 @@ the maximum time has passed. Defaults to 0s.`) pflags.StringVar(&c.conf.HTTPPort, "http-port", "9090", "Port for the Prometheus server to use") pflags.BoolVar(&c.conf.Debug, "debug", false, - "Enable the admin server on localhost") + "Enable pprof on the localhost admin server") + pflags.BoolVar(&c.conf.QuitQuitQuit, "quitquitquit", false, + "Enable quitquitquit endpoint on the localhost admin server") pflags.StringVar(&c.conf.AdminPort, "admin-port", "9091", "Port for localhost-only admin server") pflags.BoolVar(&c.conf.HealthCheck, "health-check", false, @@ -619,7 +627,7 @@ func parseConfig(cmd *Command, conf *proxy.Config, args []string) error { } // runSignalWrapper watches for SIGTERM and SIGINT and interupts execution if necessary. -func runSignalWrapper(cmd *Command) error { +func runSignalWrapper(cmd *Command) (err error) { defer cmd.cleanup() ctx, cancel := context.WithCancel(cmd.Context()) defer cancel() @@ -653,21 +661,6 @@ func runSignalWrapper(cmd *Command) error { }() } - var ( - needsHTTPServer bool - mux = http.NewServeMux() - ) - if cmd.conf.Prometheus { - needsHTTPServer = true - e, err := prometheus.NewExporter(prometheus.Options{ - Namespace: cmd.conf.PrometheusNamespace, - }) - if err != nil { - return err - } - mux.Handle("/metrics", e) - } - shutdownCh := make(chan error) // watch for sigterm / sigint signals signals := make(chan os.Signal, 1) @@ -711,10 +704,27 @@ func runSignalWrapper(cmd *Command) error { defer func() { if cErr := p.Close(); cErr != nil { cmd.logger.Errorf("error during shutdown: %v", cErr) + // Capture error from close to propagate it to the caller. + err = cErr } }() - notify := func() {} + var ( + needsHTTPServer bool + mux = http.NewServeMux() + notify = func() {} + ) + if cmd.conf.Prometheus { + needsHTTPServer = true + e, err := prometheus.NewExporter(prometheus.Options{ + Namespace: cmd.conf.PrometheusNamespace, + }) + if err != nil { + return err + } + mux.Handle("/metrics", e) + } + if cmd.conf.HealthCheck { needsHTTPServer = true cmd.logger.Infof("Starting health check server at %s", @@ -725,54 +735,51 @@ func runSignalWrapper(cmd *Command) error { mux.HandleFunc("/liveness", hc.HandleLiveness) notify = hc.NotifyStarted } + // Start the HTTP server if anything requiring HTTP is specified. + if needsHTTPServer { + go startHTTPServer( + ctx, + cmd.logger, + net.JoinHostPort(cmd.conf.HTTPAddress, cmd.conf.HTTPPort), + mux, + shutdownCh, + ) + } - go func() { - if !cmd.conf.Debug { - return - } - m := http.NewServeMux() + var ( + needsAdminServer bool + m = http.NewServeMux() + ) + if cmd.conf.QuitQuitQuit { + needsAdminServer = true + cmd.logger.Infof("Enabling quitquitquit endpoint at localhost:%v", cmd.conf.AdminPort) + // quitquitquit allows for shutdown on localhost only. + var quitOnce sync.Once + m.HandleFunc("/quitquitquit", quitquitquit(&quitOnce, shutdownCh)) + } + if cmd.conf.Debug { + needsAdminServer = true + cmd.logger.Infof("Enabling pprof endpoints at localhost:%v", cmd.conf.AdminPort) + // pprof standard endpoints m.HandleFunc("/debug/pprof/", pprof.Index) m.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) m.HandleFunc("/debug/pprof/profile", pprof.Profile) m.HandleFunc("/debug/pprof/symbol", pprof.Symbol) m.HandleFunc("/debug/pprof/trace", pprof.Trace) - addr := net.JoinHostPort("localhost", cmd.conf.AdminPort) - cmd.logger.Infof("Starting admin server on %v", addr) - if lErr := http.ListenAndServe(addr, m); lErr != nil { - cmd.logger.Errorf("Failed to start admin HTTP server: %v", lErr) - } - }() - // Start the HTTP server if anything requiring HTTP is specified. - if needsHTTPServer { - server := &http.Server{ - Addr: net.JoinHostPort(cmd.conf.HTTPAddress, cmd.conf.HTTPPort), - Handler: mux, - } - // Start the HTTP server. - go func() { - err := server.ListenAndServe() - if err == http.ErrServerClosed { - return - } - if err != nil { - shutdownCh <- fmt.Errorf("failed to start HTTP server: %v", err) - } - }() - // Handle shutdown of the HTTP server gracefully. - go func() { - <-ctx.Done() - // Give the HTTP server a second to shutdown cleanly. - ctx2, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if err := server.Shutdown(ctx2); err != nil { - cmd.logger.Errorf("failed to shutdown Prometheus HTTP server: %v\n", err) - } - }() + } + if needsAdminServer { + go startHTTPServer( + ctx, + cmd.logger, + net.JoinHostPort("localhost", cmd.conf.AdminPort), + m, + shutdownCh, + ) } go func() { shutdownCh <- p.Serve(ctx, notify) }() - err := <-shutdownCh + err = <-shutdownCh switch { case errors.Is(err, errSigInt): cmd.logger.Errorf("SIGINT signal received. Shutting down...") @@ -783,3 +790,45 @@ func runSignalWrapper(cmd *Command) error { } return err } + +func quitquitquit(quitOnce *sync.Once, shutdownCh chan<- error) http.HandlerFunc { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + rw.WriteHeader(400) + return + } + quitOnce.Do(func() { + select { + case shutdownCh <- errQuitQuitQuit: + default: + // The write attempt to shutdownCh failed and + // the proxy is already exiting. + } + }) + }) +} + +func startHTTPServer(ctx context.Context, l alloydb.Logger, addr string, mux *http.ServeMux, shutdownCh chan<- error) { + server := &http.Server{ + Addr: addr, + Handler: mux, + } + // Start the HTTP server. + go func() { + err := server.ListenAndServe() + if err == http.ErrServerClosed { + return + } + if err != nil { + shutdownCh <- fmt.Errorf("failed to start HTTP server: %v", err) + } + }() + // Handle shutdown of the HTTP server gracefully. + <-ctx.Done() + // Give the HTTP server a second to shutdown cleanly. + ctx2, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := server.Shutdown(ctx2); err != nil { + l.Errorf("failed to shutdown HTTP server: %v\n", err) + } +} diff --git a/cmd/root_linux_test.go b/cmd/root_linux_test.go index cfa1441e..f95e2914 100644 --- a/cmd/root_linux_test.go +++ b/cmd/root_linux_test.go @@ -32,14 +32,14 @@ func TestNewCommandArgumentsOnLinux(t *testing.T) { }{ { desc: "using the fuse flag", - args: []string{"--fuse", "/cloudsql"}, - wantDir: "/cloudsql", + args: []string{"--fuse", "/alloydb"}, + wantDir: "/alloydb", wantTempDir: defaultTmp, }, { desc: "using the fuse temporary directory flag", - args: []string{"--fuse", "/cloudsql", "--fuse-tmp-dir", "/mycooldir"}, - wantDir: "/cloudsql", + args: []string{"--fuse", "/alloydb", "--fuse-tmp-dir", "/mycooldir"}, + wantDir: "/alloydb", wantTempDir: "/mycooldir", }, } diff --git a/cmd/root_test.go b/cmd/root_test.go index 37799f05..ee2c9289 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -20,6 +20,7 @@ import ( "fmt" "net" "net/http" + "net/url" "os" "path/filepath" "strings" @@ -324,6 +325,14 @@ func TestNewCommandArguments(t *testing.T) { AdminPort: "7777", }), }, + { + desc: "using the quitquitquit flag", + args: []string{"--quitquitquit", + "projects/proj/locations/region/clusters/clust/instances/inst"}, + want: withDefaults(&proxy.Config{ + QuitQuitQuit: true, + }), + }, } for _, tc := range tcs { @@ -539,6 +548,14 @@ func TestNewCommandWithEnvironmentConfig(t *testing.T) { AdminPort: "7777", }), }, + { + desc: "using the quitquitquit envvar", + envName: "ALLOYDB_PROXY_QUITQUITQUIT", + envValue: "true", + want: withDefaults(&proxy.Config{ + QuitQuitQuit: true, + }), + }, } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { @@ -899,17 +916,22 @@ func TestCommandWithCustomDialer(t *testing.T) { }, 10) } -func tryDial(addr string) (*http.Response, error) { +func tryDial(method, addr string) (*http.Response, error) { var ( resp *http.Response attempts int err error ) + u, err := url.Parse(addr) + if err != nil { + return nil, err + } + req := &http.Request{Method: method, URL: u} for { if attempts > 10 { return resp, err } - resp, err = http.Get(addr) + resp, err = http.DefaultClient.Do(req) if err != nil { attempts++ time.Sleep(time.Second) @@ -933,7 +955,7 @@ func TestPrometheusMetricsEndpoint(t *testing.T) { // try to dial metrics server for a max of ~10s to give the proxy time to // start up. - resp, err := tryDial("http://localhost:9090/metrics") // default port set by http-port flag + resp, err := tryDial("GET", "http://localhost:9090/metrics") // default port set by http-port flag if err != nil { t.Fatalf("failed to dial metrics endpoint: %v", err) } @@ -952,11 +974,83 @@ func TestPProfServer(t *testing.T) { defer cancel() go c.ExecuteContext(ctx) - resp, err := tryDial("http://localhost:9191/debug/pprof/") + resp, err := tryDial("GET", "http://localhost:9191/debug/pprof/") + if err != nil { + t.Fatalf("failed to dial endpoint: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected a 200 status, got = %v", resp.StatusCode) + } +} + +func TestQuitQuitQuit(t *testing.T) { + c := NewCommand(WithDialer(&spyDialer{})) + c.SilenceUsage = true + c.SilenceErrors = true + c.SetArgs([]string{"--quitquitquit", "--admin-port", "9192", + "projects/proj/locations/region/clusters/clust/instances/inst"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errCh := make(chan error) + go func() { + err := c.ExecuteContext(ctx) + errCh <- err + }() + resp, err := tryDial("GET", "http://localhost:9192/quitquitquit") + if err != nil { + t.Fatalf("failed to dial endpoint: %v", err) + } + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected a 400 status, got = %v", resp.StatusCode) + } + resp, err = http.Post("http://localhost:9192/quitquitquit", "", nil) if err != nil { t.Fatalf("failed to dial endpoint: %v", err) } if resp.StatusCode != http.StatusOK { t.Fatalf("expected a 200 status, got = %v", resp.StatusCode) } + if want, got := errQuitQuitQuit, <-errCh; !errors.Is(got, want) { + t.Fatalf("want = %v, got = %v", want, got) + } +} + +type errorDialer struct { + spyDialer +} + +var errCloseFailed = errors.New("close failed") + +func (*errorDialer) Close() error { + return errCloseFailed +} + +func TestQuitQuitQuitWithErrors(t *testing.T) { + c := NewCommand(WithDialer(&errorDialer{})) + c.SilenceUsage = true + c.SilenceErrors = true + c.SetArgs([]string{ + "--quitquitquit", "--admin-port", "9193", + "projects/proj/locations/region/clusters/clust/instances/inst"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errCh := make(chan error) + go func() { + err := c.ExecuteContext(ctx) + errCh <- err + }() + resp, err := tryDial("POST", "http://localhost:9193/quitquitquit") + if err != nil { + t.Fatalf("failed to dial endpoint: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected a 200 status, got = %v", resp.StatusCode) + } + // The returned error is the error from closing the dialer. + got := <-errCh + if !strings.Contains(got.Error(), "close failed") { + t.Fatalf("want = %v, got = %v", errCloseFailed, got) + } } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index f299d174..4a626e3e 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -159,6 +159,9 @@ type Config struct { // Debug enables a debug handler on localhost. Debug bool + // QuitQuitQuit enables a handler that will shut the Proxy down upon + // receiving a POST request. + QuitQuitQuit bool // OtherUserAgents is a list of space separate user agents that will be // appended to the default user agent.