diff --git a/cmdutil/v2/debug/config.go b/cmdutil/v2/debug/config.go new file mode 100644 index 00000000..360e294f --- /dev/null +++ b/cmdutil/v2/debug/config.go @@ -0,0 +1,6 @@ +package debug + +// Config describes the configurable parameters for debugging. +type Config struct { + Port int `env:"DEBUG_PORT,default=9999"` +} diff --git a/cmdutil/v2/debug/debug.go b/cmdutil/v2/debug/debug.go new file mode 100644 index 00000000..9a995661 --- /dev/null +++ b/cmdutil/v2/debug/debug.go @@ -0,0 +1,74 @@ +// Package debug wraps the gops agent for use as a cmdutil-compatible Server. +// +// The debug server will be started on DEBUG_PORT (default 9999). Get a stack +// trace, profile memory, etc. by running the gops command line connected to +// locahost:9999 like: +// +// $ gops stack localhost:9999 +// goroutine 50 [running]: +// runtime/pprof.writeGoroutineStacks(0x4a18a20, 0xc000010138, 0x0, 0x0) +// /usr/local/Cellar/go/1.13.5/libexec/src/runtime/pprof/pprof.go:679 +0x9d +// runtime/pprof.writeGoroutine(0x4a18a20, 0xc000010138, 0x2, 0x0, 0x0) +// ... +// +// Learn more about gops at https://github.com/google/gops. +package debug + +import ( + "fmt" + + "log/slog" + + "github.com/google/gops/agent" +) + +// New inializes a debug server listening on the provided port. +// +// Connect to the debug server with gops: +// +// gops stack localhost:PORT +func New(l *slog.Logger, port int) *Server { + return &Server{ + logger: l, + addr: fmt.Sprintf("127.0.0.1:%d", port), + done: make(chan struct{}), + } +} + +// Server wraps a gops server for easy use with oklog/group. +type Server struct { + logger *slog.Logger + addr string + done chan struct{} +} + +// Run starts the debug server. +// +// It implements oklog group's runFn. +func (s *Server) Run() error { + s.logger.Info("", + "at", "binding", + "service", "debug", + "addr", s.addr, + ) + + opts := agent.Options{ + Addr: s.addr, + ShutdownCleanup: false, + } + if err := agent.Listen(opts); err != nil { + return err + } + + <-s.done + return nil +} + +// Stop shuts down the debug server. +// +// It implements oklog group's interruptFn. +func (s *Server) Stop(_ error) { + agent.Close() + + close(s.done) +} diff --git a/cmdutil/v2/health/config.go b/cmdutil/v2/health/config.go new file mode 100644 index 00000000..36f3215a --- /dev/null +++ b/cmdutil/v2/health/config.go @@ -0,0 +1,8 @@ +package health + +// Config can be used in a service's main config struct to load the +// healthcheck port from the environment. +type Config struct { + Port int `env:"HEROKU_ROUTER_HEALTHCHECK_PORT,default=6000"` + MetricInterval int `env:"HEROKU_HEALTH_METRIC_INTERVAL,default=5"` +} diff --git a/cmdutil/v2/health/serve.go b/cmdutil/v2/health/serve.go new file mode 100644 index 00000000..ca0647a2 --- /dev/null +++ b/cmdutil/v2/health/serve.go @@ -0,0 +1,47 @@ +// Package health provides cmdutil-compatible healthcheck utilities. +package health + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/heroku/x/cmdutil" + "github.com/heroku/x/go-kit/metrics" + "github.com/heroku/x/tickgroup" + "github.com/heroku/x/v2/healthcheck" +) + +// NewTCPServer returns a cmdutil.Server which emits a health metric whenever a TCP +// connection is opened on the configured port. +func NewTCPServer(logger *slog.Logger, provider metrics.Provider, cfg Config) cmdutil.Server { + healthLogger := logger.With(slog.String("service", "healthcheck")) + healthLogger.With( + slog.String("at", "binding"), + slog.Int("port", cfg.Port), + ).Info("") + + return healthcheck.NewTCPServer(healthLogger, provider, fmt.Sprintf(":%d", cfg.Port)) +} + +// NewTickingServer returns a cmdutil.Server which emits a health metric every +// cfg.MetricInterval seconds. +func NewTickingServer(logger *slog.Logger, provider metrics.Provider, cfg Config) cmdutil.Server { + logger.With( + slog.String("service", "healthcheck-worker"), + slog.String("at", "starting"), + slog.Int("interval", cfg.MetricInterval), + ).Info("") + + c := provider.NewCounter("health") + + return cmdutil.NewContextServer(func(ctx context.Context) error { + g := tickgroup.New(ctx) + g.Go(time.Duration(cfg.MetricInterval)*time.Second, func() error { + c.Add(1) + return nil + }) + return g.Wait() + }) +} diff --git a/cmdutil/v2/service/config.go b/cmdutil/v2/service/config.go new file mode 100644 index 00000000..29b11017 --- /dev/null +++ b/cmdutil/v2/service/config.go @@ -0,0 +1,90 @@ +package service + +import ( + "net/url" + "time" + + "github.com/heroku/x/cmdutil/debug" + "github.com/heroku/x/cmdutil/metrics" + "github.com/heroku/x/cmdutil/oc" + "github.com/heroku/x/cmdutil/rollbar" + "github.com/heroku/x/cmdutil/v2/svclog" +) + +// standardConfig is used when service.New is called. +type standardConfig struct { + Debug debug.Config + Logger svclog.Config + Metrics metrics.Config + Rollbar rollbar.Config + OpenCensus oc.Config +} + +// platformConfig is used by HTTP and captures +// config related to running on the Heroku platform. +type platformConfig struct { + // Port is the primary port to listen on when running as a normal platform + // app. + Port int `env:"PORT"` + + // AdditionalPort defines an additional port to listen on in addition to the + // primary port for use with dyno-dyno networking. + AdditionalPort int `env:"ADDITIONAL_PORT"` +} + +// bypassConfig is used by HTTP and GRPC and captures +// config related to running with the router bypass +// feature on the Heroku platform. +type bypassConfig struct { + // The following ports, TLS, and ACM configurations are set when running with + // spaces-router-bypass enabled. + InsecurePort int `env:"HEROKU_ROUTER_HTTP_PORT"` + SecurePort int `env:"HEROKU_ROUTER_HTTPS_PORT"` + HealthPort int `env:"HEROKU_ROUTER_HEALTHCHECK_PORT"` + TLS tlsConfig + ACMEHTTPValidationURL *url.URL `env:"ACME_HTTP_VALIDATION_URL,default=https://va-acm.runtime.herokai.com/challenge"` +} + +// tlsConfig is used by bypassConfig and captures config related to TLS +// when not running on the Heroku platform. +type tlsConfig struct { + // These environement variables are automatically set by Foundation in + // relation to Let's Encrypt certificates. + ServerCert string `env:"SERVER_CERT"` + ServerKey string `env:"SERVER_KEY"` + + // Used by GRPC services, set by terraform. + ServerCACert string `env:"SERVER_CA_CERT"` + + UseAutocert bool `env:"HTTPS_USE_AUTOCERT"` +} + +// spaceCAConfig is used by grpcConfig and captures config related to +// common runtime services whose certs are generated using the +// spaceCA. +type spaceCAConfig struct { + // Used by GRPC services in new mTLS cert generation where services + // generate their certificates using the SpaceCA. + RootCACert string `env:"HEROKU_SPACE_CA_ROOT_CERT"` + SpaceCACert string `env:"HEROKU_SPACE_CA_CERT"` + SpaceCAKey string `env:"HEROKU_SPACE_CA_KEY"` + + // RootCACertAlternate is set during a root certificate rotation and must be + // installed into the certificate pool to ensure that services are able to + // communicate while the rotation is in progress. + RootCACertAlternate string `env:"HEROKU_SPACE_CA_ROOT_CERT_ALTERNATE"` + + // Switch which will determine whether an app generates their cert + // using the SpaceCA. + UseSpaceCA bool `env:"USE_SPACE_CA,default=false"` + + // Domain of the service used in the generation of the cert + Domain string `env:"DOMAIN"` +} + +type timeoutConfig struct { + Read time.Duration `env:"SERVER_READ_TIMEOUT"` + ReadHeader time.Duration `env:"SERVER_READ_HEADER_TIMEOUT,default=30s"` + Write time.Duration `env:"SERVER_WRITE_TIMEOUT"` + Idle time.Duration `env:"SERVER_IDLE_TIMEOUT"` +} diff --git a/cmdutil/v2/service/config_test.go b/cmdutil/v2/service/config_test.go new file mode 100644 index 00000000..301751d2 --- /dev/null +++ b/cmdutil/v2/service/config_test.go @@ -0,0 +1,35 @@ +package service + +import ( + "testing" + + "github.com/joeshaw/envdecode" +) + +// httpConfig should be decodable with nothing required. +// +// This isn't a perfect test as there may be something set +// in the test environment that is used by httpConfig but it +// should help ensure at least more specific items like +// SERVER_CA_CERT are not required. +func TestDecodeHTTPConfig(t *testing.T) { + var cfg httpConfig + + if err := envdecode.StrictDecode(&cfg); err != nil { + t.Fatal(err) + } +} + +// grpcConfig should be decodable with nothing required. +// +// This isn't a perfect test as there may be something set +// in the test environment that is used by grpcConfig but it +// should help ensure at least more specific items like +// SERVER_CA_CERT are not required. +func TestDecodeGRPCConfig(t *testing.T) { + var cfg grpcConfig + + if err := envdecode.StrictDecode(&cfg); err != nil { + t.Fatal(err) + } +} diff --git a/cmdutil/v2/service/doc.go b/cmdutil/v2/service/doc.go new file mode 100644 index 00000000..9ad537a9 --- /dev/null +++ b/cmdutil/v2/service/doc.go @@ -0,0 +1,3 @@ +// Package service provides standardized command and HTTP setup by smartly +// composing the other cmdutil packages based on environment variables. +package service diff --git a/cmdutil/v2/service/http_example_test.go b/cmdutil/v2/service/http_example_test.go new file mode 100644 index 00000000..7dfe1667 --- /dev/null +++ b/cmdutil/v2/service/http_example_test.go @@ -0,0 +1,30 @@ +package service_test + +import ( + "io" + "net/http" + "time" + + "github.com/heroku/x/cmdutil/v2/service" +) + +func ExampleWithHTTPServerHook() { + var cfg struct { + Hello string `env:"HELLO,default=hello"` + } + svc := service.New(&cfg) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, cfg.Hello) + }) + + configureHTTP := func(s *http.Server) { + s.ReadTimeout = 10 * time.Second + } + + svc.Add(service.HTTP(svc.Logger, svc.MetricsProvider, handler, + service.WithHTTPServerHook(configureHTTP), + )) + + svc.Run() +} diff --git a/cmdutil/v2/service/integration_test.go b/cmdutil/v2/service/integration_test.go new file mode 100644 index 00000000..abffa926 --- /dev/null +++ b/cmdutil/v2/service/integration_test.go @@ -0,0 +1,38 @@ +//go:build integration +// +build integration + +package service + +import ( + "os" + "testing" + + "github.com/heroku/x/cmdutil" + "github.com/heroku/x/go-kit/metrics/l2met" +) + +func TestPanicReporting(t *testing.T) { + + os.Setenv("APP_NAME", "test-app") + os.Setenv("DEPLOY", "test") + + t.Cleanup(func() { + os.Unsetenv("APP_NAME") + os.Unsetenv("DEPLOY") + }) + + var cfg struct { + Val string `env:"TEST_VAL,default=test"` + } + + s := New(&cfg) + + f := func() error { + panic("test panic") + return nil + } + + s.Add(cmdutil.ServerFunc(f)) + s.Run() + +} diff --git a/cmdutil/v2/service/standard.go b/cmdutil/v2/service/standard.go new file mode 100644 index 00000000..7c22e8b2 --- /dev/null +++ b/cmdutil/v2/service/standard.go @@ -0,0 +1,143 @@ +package service + +import ( + "os" + "strings" + "syscall" + + "log/slog" + + "github.com/joeshaw/envdecode" + "github.com/oklog/run" + + "github.com/heroku/x/cmdutil" + "github.com/heroku/x/cmdutil/metrics" + "github.com/heroku/x/cmdutil/v2/debug" + "github.com/heroku/x/cmdutil/v2/signals" + "github.com/heroku/x/cmdutil/v2/svclog" + xmetrics "github.com/heroku/x/go-kit/metrics" +) + +// Standard is a standard service. +type Standard struct { + g run.Group + + App string + Deploy string + Logger *slog.Logger + MetricsProvider xmetrics.Provider +} + +// New Standard Service with logging, rollbar, metrics, debugging, common signal +// handling, and possibly more. +// +// If appConfig is non-nil, envdecode.MustStrictDecode will be called on it +// to ensure that it is processed. +func New(appConfig interface{}, ofs ...OptionFunc) *Standard { + var sc standardConfig + envdecode.MustStrictDecode(&sc) + + if appConfig != nil { + envdecode.MustStrictDecode(appConfig) + } + + logger := svclog.NewLogger(sc.Logger) + + // TODO: Add rollbar support. + var o options + for _, of := range ofs { + of(&o) + } + + if !o.skipMetricsSuffix && sc.Metrics.Prefix != "" { + suf := o.customMetricsSuffix + if suf == "" { + suf = metricsSuffixFromDyno(sc.Logger.Dyno) + } + if suf != "" { + sc.Metrics.Prefix += "." + suf + } + } + + s := &Standard{ + App: sc.Logger.AppName, + Deploy: sc.Logger.Deploy, + Logger: logger, + } + + s.Add(debug.New(logger, sc.Debug.Port)) + s.Add(signals.NewServer(logger, syscall.SIGINT, syscall.SIGTERM)) + + return s +} + +// Add adds cmdutil.Servers to be managed. +func (s *Standard) Add(svs ...cmdutil.Server) { + for _, sv := range svs { + sv := sv + runWithPanicReport := func() error { + defer metrics.ReportPanic(s.MetricsProvider) + defer svclog.ReportPanic(s.Logger) + return sv.Run() + } + s.g.Add(runWithPanicReport, sv.Stop) + } +} + +// Run runs all standard and Added cmdutil.Servers. +// +// If a panic is encountered, it is reported to Rollbar. +// +// If the error returned by oklog/run.Run is non-nil, it is logged +// with s.Logger.Fatal. +func (s *Standard) Run() { + err := s.g.Run() + + if s.MetricsProvider != nil { + s.MetricsProvider.Stop() + } + + if err != nil { + s.Logger.Error(err.Error()) + os.Exit(1) + } +} + +type options struct { + customMetricsSuffix string + skipMetricsSuffix bool +} + +// OptionFunc is a function that modifies internal service options. +type OptionFunc func(*options) + +// SkipMetricsSuffix prevents the Service from suffixing the process type to +// metric names recorded by the MetricsProvider. The default suffix is +// determined from $DYNO. +func SkipMetricsSuffix() OptionFunc { + return func(o *options) { + o.skipMetricsSuffix = true + } +} + +// CustomMetricsSuffix to be added to metrics recorded by the MetricsProvider +// instead of inferring it from $DYNO. +func CustomMetricsSuffix(s string) OptionFunc { + return func(o *options) { + o.customMetricsSuffix = s + } +} + +// metricsSuffixFromDyno determines a metrics suffix from the process part of +// $DYNO. If $DYNO indicates a "web" process, the suffix is "server". If $DYNO +// is empty, so is the suffix. +func metricsSuffixFromDyno(dyno string) string { + if dyno == "" { + return dyno + } + parts := strings.SplitN(dyno, ".", 2) + if parts[0] == "web" { + parts[0] = "server" // TODO[freeformz]: Document why this is server + } + return parts[0] +} diff --git a/cmdutil/v2/service/standard_grpc.go b/cmdutil/v2/service/standard_grpc.go new file mode 100644 index 00000000..062d6303 --- /dev/null +++ b/cmdutil/v2/service/standard_grpc.go @@ -0,0 +1,95 @@ +package service + +import ( + "crypto/tls" + "log/slog" + "os" + + "github.com/joeshaw/envdecode" + "github.com/pkg/errors" + + "github.com/heroku/x/cmdutil" + "github.com/heroku/x/cmdutil/spaceca" + "github.com/heroku/x/cmdutil/v2/health" + "github.com/heroku/x/go-kit/metrics" + "github.com/heroku/x/grpc/v2/grpcserver" +) + +type grpcConfig struct { + Bypass bypassConfig + SpaceCA spaceCAConfig +} + +func loadMutualTLSCert(cfg grpcConfig) (tls.Certificate, [][]byte, error) { + if cfg.SpaceCA.UseSpaceCA { + ca := spaceca.CA{ + RootCert: []byte(cfg.SpaceCA.RootCACert), + Cert: []byte(cfg.SpaceCA.SpaceCACert), + Key: []byte(cfg.SpaceCA.SpaceCAKey), + } + domain := cfg.SpaceCA.Domain + + cert, err := spaceca.NewCACertificate(domain, ca) + if err != nil { + return tls.Certificate{}, nil, errors.Wrap(err, "error generating cert from spaceCA") + } + + serverCACertList := [][]byte{ca.RootCert} + + if cfg.SpaceCA.RootCACertAlternate != "" { + serverCACertList = append(serverCACertList, []byte(cfg.SpaceCA.RootCACertAlternate)) + } + + return *cert, serverCACertList, nil + } + + serverCert, err := tls.X509KeyPair([]byte(cfg.Bypass.TLS.ServerCert), []byte(cfg.Bypass.TLS.ServerKey)) + if err != nil { + return tls.Certificate{}, nil, errors.Wrap(err, "creating X509 key pair") + } + serverCACertList := [][]byte{[]byte(cfg.Bypass.TLS.ServerCACert)} + + return serverCert, serverCACertList, nil +} + +// GRPC returns a standard GRPC server for the provided handler. +// Router-bypass and TLS config are inferred from the environment. +// +// Currently only supports running in router-bypass mode, unlike HTTP. +func GRPC( + l *slog.Logger, + m metrics.Provider, + server grpcserver.Starter, + grpcOpts ...grpcserver.ServerOption) cmdutil.Server { + var cfg grpcConfig + envdecode.MustDecode(&cfg) + + cert, serverCACertList, err := loadMutualTLSCert(cfg) + if err != nil { + l.Error(err.Error()) + os.Exit(1) + } + + var srvs []cmdutil.Server + + if cfg.Bypass.SecurePort != 0 { + grpcOpts = append(grpcOpts, grpcserver.MetricsProvider(m)) + + srvs = append(srvs, grpcserver.NewStandardServer( + l, + cfg.Bypass.SecurePort, + serverCACertList, + cert, + server, + grpcOpts..., + )) + } + + if cfg.Bypass.HealthPort != 0 { + srvs = append(srvs, health.NewTCPServer(l, m, health.Config{ + Port: cfg.Bypass.HealthPort, + })) + } + + return cmdutil.MultiServer(srvs...) +} diff --git a/cmdutil/v2/service/standard_http.go b/cmdutil/v2/service/standard_http.go new file mode 100644 index 00000000..e7a74f5e --- /dev/null +++ b/cmdutil/v2/service/standard_http.go @@ -0,0 +1,240 @@ +package service + +import ( + "context" + "crypto/tls" + "fmt" + "log/slog" + "net" + "net/http" + "os" + "time" + + proxyproto "github.com/armon/go-proxyproto" + "github.com/joeshaw/envdecode" + "github.com/pkg/errors" + "golang.org/x/crypto/acme/autocert" + + "github.com/heroku/x/cmdutil" + "github.com/heroku/x/cmdutil/https" + "github.com/heroku/x/cmdutil/v2/health" + "github.com/heroku/x/go-kit/metrics" + "github.com/heroku/x/tlsconfig" +) + +type httpConfig struct { + Platform platformConfig + Bypass bypassConfig + Timeouts timeoutConfig +} + +// HTTP returns a standard HTTP server for the provided handler. Port, TLS, and +// router-bypass config are inferred from the environment. +func HTTP(l *slog.Logger, m metrics.Provider, h http.Handler, opts ...func(*httpOptions)) cmdutil.Server { + var cfg httpConfig + envdecode.MustDecode(&cfg) + + var o httpOptions + for _, opt := range opts { + opt(&o) + } + + if !o.skipEnforceHTTPS { + h = https.RedirectHandler(h) + } + + var srvs []cmdutil.Server + + if cfg.Platform.Port != 0 { + + s := httpServerWithTimeouts(cfg.Timeouts) + s.Handler = h + s.Addr = fmt.Sprintf(":%d", cfg.Platform.Port) + o.configureServer(s) + srvs = append(srvs, standardServer(l, s)) + } + + if cfg.Platform.AdditionalPort != 0 { + s := httpServerWithTimeouts(cfg.Timeouts) + + s.Handler = h + s.Addr = fmt.Sprintf(":%d", cfg.Platform.AdditionalPort) + o.configureServer(s) + srvs = append(srvs, standardServer(l, s)) + } + + if cfg.Bypass.InsecurePort != 0 { + + s := httpServerWithTimeouts(cfg.Timeouts) + s.Handler = h + s.Addr = fmt.Sprintf(":%d", cfg.Bypass.InsecurePort) + + o.configureServer(s) + srvs = append(srvs, bypassServer(l, s)) + } + + if cfg.Bypass.SecurePort != 0 { + tlsConfig := o.tlsConfig + if tlsConfig == nil { + tlsConfig = newTLSConfig(cfg.Bypass.TLS) + } + + s := httpServerWithTimeouts(cfg.Timeouts) + s.Handler = h + s.TLSConfig = tlsConfig + s.Addr = fmt.Sprintf(":%d", cfg.Bypass.SecurePort) + + o.configureServer(s) + srvs = append(srvs, bypassServer(l, s)) + } + + if cfg.Bypass.HealthPort != 0 { + srvs = append(srvs, health.NewTCPServer(l, m, health.Config{ + Port: cfg.Bypass.HealthPort, + })) + } + + return cmdutil.MultiServer(srvs...) +} + +func httpServerWithTimeouts(t timeoutConfig) *http.Server { + return &http.Server{ + ReadTimeout: t.Read, + ReadHeaderTimeout: t.ReadHeader, + WriteTimeout: t.Write, + IdleTimeout: t.Idle, + } +} + +type httpOptions struct { + skipEnforceHTTPS bool + tlsConfig *tls.Config + serverHook func(*http.Server) +} + +func (o *httpOptions) configureServer(s *http.Server) { + if o.serverHook != nil { + o.serverHook(s) + } +} + +// SkipEnforceHTTPS allows services to opt-out of SSL enforcement required for +// productionization. It should only be used in environments where SSL is not +// available. +func SkipEnforceHTTPS() func(*httpOptions) { + return func(o *httpOptions) { + o.skipEnforceHTTPS = true + } +} + +// WithHTTPServerHook allows services to provide a function to +// adjust settings on any HTTP server before after the defaults are +// applied but before the server is started. +func WithHTTPServerHook(fn func(*http.Server)) func(*httpOptions) { + return func(o *httpOptions) { + o.serverHook = fn + } +} + +// WithTLSConfig allows services to use a specific TLS configuration instead of +// the default one constructed from environment variables. +func WithTLSConfig(tlscfg *tls.Config) func(*httpOptions) { + return func(o *httpOptions) { + o.tlsConfig = tlscfg + } +} + +func newTLSConfig(cfg tlsConfig) *tls.Config { + var ( + serverCert = []byte(cfg.ServerCert) + serverKey = []byte(cfg.ServerKey) + ) + + tlsConfig := tlsconfig.New() + + if cfg.UseAutocert { + am := &autocert.Manager{ + Prompt: autocert.AcceptTOS, + } + tlsConfig.GetCertificate = am.GetCertificate + } else { + cert, err := tls.X509KeyPair(serverCert, serverKey) + if err != nil { + slog.Error("unable to load TLS config", slog.String("error", err.Error())) + os.Exit(1) + } + + tlsConfig.Certificates = []tls.Certificate{cert} + } + + return tlsConfig +} + +// listenHook allows tests to intercept the listener created for standard and +// bypass servers, e.g., to get the resolved address when the server's Addr is +// `:0`. +var listenHook chan net.Listener + +// standardServer adapts an http.Server to a cmdutil.Server. The server is expected +// to be run behind a router and does not terminate TLS. +func standardServer(l *slog.Logger, srv *http.Server) cmdutil.Server { + return cmdutil.ServerFuncs{ + RunFunc: func() error { + l.Info("", slog.String("at", "binding"), slog.String("addr", srv.Addr)) + + ln, err := net.Listen("tcp", srv.Addr) + if err != nil { + return err + } + defer ln.Close() + + if listenHook != nil { + listenHook <- ln + } + + return srv.Serve(ln) + }, + StopFunc: func(error) { gracefulShutdown(l, srv) }, + } +} + +// bypassServer adapts an http.Server to a cmdutil.Server. The server is expected +// to be directly behind an ELB and uses proxyprotocol. It terminates TLS if +// TLSConfig is set on srv. +func bypassServer(l *slog.Logger, srv *http.Server) cmdutil.Server { + return cmdutil.ServerFuncs{ + RunFunc: func() error { + l.Info("", slog.String("at", "binding"), slog.String("addr", srv.Addr)) + + ln, err := net.Listen("tcp", srv.Addr) + if err != nil { + return errors.Wrap(err, "listening to tcp addr") + } + defer ln.Close() + + if listenHook != nil { + listenHook <- ln + } + + ln = &proxyproto.Listener{Listener: ln} + + if srv.TLSConfig != nil { + return srv.ServeTLS(ln, "", "") + } + + return srv.Serve(ln) + }, + StopFunc: func(error) { gracefulShutdown(l, srv) }, + } +} + +func gracefulShutdown(l *slog.Logger, s *http.Server) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + l.Info("", slog.String("at", "graceful-shutdown"), slog.String("addr", s.Addr)) + if err := s.Shutdown(ctx); err != nil { + l.Warn("", slog.String("at", "graceful-shutdown"), slog.String("error", err.Error())) + s.Close() + } +} diff --git a/cmdutil/v2/service/standard_http_test.go b/cmdutil/v2/service/standard_http_test.go new file mode 100644 index 00000000..c557c0cd --- /dev/null +++ b/cmdutil/v2/service/standard_http_test.go @@ -0,0 +1,136 @@ +package service + +import ( + "bufio" + "io" + "net" + "net/http" + "os" + "testing" + + "github.com/heroku/x/go-kit/metrics/testmetrics" + "github.com/heroku/x/testing/v2/testlog" +) + +func TestStandardHTTPServer(t *testing.T) { + l, _ := testlog.New() + //nolint: gosec + srv := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + if _, err := io.WriteString(w, "OK"); err != nil { + t.Error(err) + } + }), + Addr: "127.0.0.1:0", + } + + listenHook = make(chan net.Listener) + defer func() { listenHook = nil }() + + s := standardServer(l, srv) + + done := make(chan struct{}) + go func() { + if err := s.Run(); err != nil { + t.Log(err) + } + close(done) + }() + + addr := (<-listenHook).Addr().String() + + res, err := http.Get("http://" + addr) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + + data, _ := io.ReadAll(res.Body) + if string(data) != "OK" { + t.Fatalf("want OK got %v", string(data)) + } + + s.Stop(nil) + + <-done +} + +func TestBypassHTTPServer(t *testing.T) { + l, _ := testlog.New() + //nolint: gosec + srv := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + if _, err := io.WriteString(w, "OK"); err != nil { + t.Error(err) + } + }), + Addr: "127.0.0.1:0", + } + + listenHook = make(chan net.Listener) + defer func() { listenHook = nil }() + + s := bypassServer(l, srv) + + done := make(chan struct{}) + go func() { + if err := s.Run(); err != nil { + t.Log(err) + } + close(done) + }() + + addr := (<-listenHook).Addr().String() + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + if _, err = io.WriteString(conn, "PROXY TCP4 127.0.0.1 127.0.0.1 44444 55555\n"); err != nil { + t.Fatal(err) + } + + req, _ := http.NewRequest("GET", "http://"+addr, nil) + if err := req.Write(conn); err != nil { + t.Fatal(err) + } + + r := bufio.NewReader(conn) + res, err := http.ReadResponse(r, nil) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + + data, _ := io.ReadAll(res.Body) + if string(data) != "OK" { + t.Fatalf("want OK got %v", string(data)) + } + + s.Stop(nil) + + <-done +} + +func TestHTTPServerConfiguration(t *testing.T) { + os.Setenv("PORT", "1234") + os.Setenv("ADDITIONAL_PORT", "4567") + defer func() { + os.Unsetenv("PORT") + os.Unsetenv("ADDITIONAL_PORT") + }() + + var configuredServers []string + config := func(s *http.Server) { + configuredServers = append(configuredServers, s.Addr) + } + + l, _ := testlog.New() + p := testmetrics.NewProvider(t) + HTTP(l, p, nil, WithHTTPServerHook(config)) + + if len(configuredServers) != 2 { + t.Fatalf("expected 2 servers to be configured, got %v", configuredServers) + } +} diff --git a/cmdutil/v2/service/standard_test.go b/cmdutil/v2/service/standard_test.go new file mode 100644 index 00000000..1301c8d1 --- /dev/null +++ b/cmdutil/v2/service/standard_test.go @@ -0,0 +1,57 @@ +package service_test + +import ( + "os" + "testing" + "time" + + "github.com/heroku/x/cmdutil/service" +) + +func TestNewNoConfig(t *testing.T) { + setupStandardConfig(t) + + s := service.New(nil) + + if s.Logger == nil { + t.Fatal("standard logger not configured") + } + + if s.MetricsProvider == nil { + t.Fatal("standard metrics provider not configured") + } +} + +func TestNewCustomConfig(t *testing.T) { + setupStandardConfig(t) + + os.Setenv("TEST_VAL", "1m") + defer os.Unsetenv("TEST_VAL") + + var cfg struct { + Val time.Duration `env:"TEST_VAL"` + } + s := service.New(&cfg) + + if s.Logger == nil { + t.Fatal("standard logger not configured") + } + + if s.MetricsProvider == nil { + t.Fatal("standard metrics provider not configured") + } + + if cfg.Val != time.Minute { + t.Fatalf("cfg.Val = %v want %v", cfg.Val, time.Minute) + } +} + +func setupStandardConfig(t *testing.T) { + os.Setenv("APP_NAME", "test-app") + os.Setenv("DEPLOY", "test") + + t.Cleanup(func() { + os.Unsetenv("APP_NAME") + os.Unsetenv("DEPLOY") + }) +} diff --git a/cmdutil/v2/signals/signals.go b/cmdutil/v2/signals/signals.go new file mode 100644 index 00000000..d9ac97c2 --- /dev/null +++ b/cmdutil/v2/signals/signals.go @@ -0,0 +1,55 @@ +// Package signals provides a signal handler which is usable as a cmdutil.Server. +package signals + +import ( + "context" + "log/slog" + "os" + "os/signal" + + "github.com/heroku/x/cmdutil" +) + +// WithNotifyCancel creates a sub-context from the given context which gets +// canceled upon receiving any of the configured signals. +func WithNotifyCancel(ctx context.Context, signals ...os.Signal) context.Context { + notified := make(chan os.Signal, 1) + return notifyContext(ctx, notified, signals...) +} + +func notifyContext(ctx context.Context, notified chan os.Signal, signals ...os.Signal) context.Context { + ctx, cancel := context.WithCancel(ctx) + signal.Notify(notified, signals...) + + go func() { + <-notified + cancel() + }() + + return ctx +} + +// NewServer returns a cmdutil.Server that returns from Run +// when any of the provided signals are received. +// Run always returns a nil error. +func NewServer(logger *slog.Logger, signals ...os.Signal) cmdutil.Server { + ch := make(chan os.Signal, 1) + + return cmdutil.ServerFuncs{ + RunFunc: func() error { + signal.Notify(ch, signals...) + sig := <-ch + if sig != nil { + logger.Info("received signal", "sig", sig) + } + return nil + }, + StopFunc: func(error) { + signal.Stop(ch) + select { + case ch <- nil: + default: + } + }, + } +} diff --git a/cmdutil/v2/signals/signals_test.go b/cmdutil/v2/signals/signals_test.go new file mode 100644 index 00000000..909afbad --- /dev/null +++ b/cmdutil/v2/signals/signals_test.go @@ -0,0 +1,92 @@ +package signals + +import ( + "context" + "os" + "syscall" + "testing" + "time" + + "github.com/heroku/x/testing/v2/testlog" +) + +func TestWithNotifyCancel(t *testing.T) { + notified := make(chan os.Signal, 1) + ctx := notifyContext(context.Background(), notified, syscall.SIGINT) + + notified <- syscall.SIGINT + select { + case <-ctx.Done(): + case <-time.After(time.Second): + t.Fatalf("expected ctx to be canceled") + } +} + +func TestNewServer(t *testing.T) { + logger, _ := testlog.New() + + sv := NewServer(logger, syscall.SIGWINCH) + + var ( + runErr error + runDone = make(chan struct{}) + done = make(chan struct{}) + ) + defer close(done) + + go func() { + runErr = sv.Run() + close(runDone) + }() + + // We're racing with Run starting and calling signal.Notify, so loop + // it until the test is done. + go func() { + for { + select { + case <-done: + return + default: + } + if err := syscall.Kill(syscall.Getpid(), syscall.SIGWINCH); err != nil { + t.Error(err) + } + time.Sleep(time.Millisecond) + } + }() + + select { + case <-runDone: + case <-time.After(500 * time.Millisecond): + t.Fatal("Run took too long") + } + + if runErr != nil { + t.Fatalf("got Run error %+v, want no error", runErr) + } + + sv.Stop(nil) +} + +// Ensure Run returns when Stop is called, even if no signal +// has been received. +func TestNewServerNoSignal(t *testing.T) { + logger, _ := testlog.New() + + sv := NewServer(logger, syscall.SIGWINCH) + + var runErr error + done := make(chan struct{}) + + go func() { + runErr = sv.Run() + close(done) + }() + + sv.Stop(nil) + <-done + + if runErr != nil { + t.Fatalf("got Run error %+v, want no error", runErr) + } +} diff --git a/cmdutil/v2/svclog/logger.go b/cmdutil/v2/svclog/logger.go new file mode 100644 index 00000000..3ded974d --- /dev/null +++ b/cmdutil/v2/svclog/logger.go @@ -0,0 +1,68 @@ +// Package svclog provides logging facilities for standard services. +package svclog + +import ( + "fmt" + "io" + "log" + "os" + + "log/slog" +) + +// Config for logger. +type Config struct { + AppName string `env:"APP_NAME,required"` + Deploy string `env:"DEPLOY,required"` + SpaceID string `env:"SPACE_ID"` + Dyno string `env:"DYNO"` + LogLevel string `env:"LOG_LEVEL,default=INFO"` + + WriteTo io.Writer +} + +// NewLogger returns a new logger that includes app and deploy key/value pairs +// in each log line. +func NewLogger(cfg Config) *slog.Logger { + level, err := ParseLevel(cfg.LogLevel) + if err != nil { + log.Fatal(err) + } + + hopts := &slog.HandlerOptions{ + Level: level, + } + var w io.Writer + w = cfg.WriteTo + if w == nil { + w = os.Stdout + } + logger := slog.New(slog.NewTextHandler(w, hopts)).With( + "app", cfg.AppName, + "deploy", cfg.Deploy, + ) + + if cfg.SpaceID != "" { + logger = logger.With("space", cfg.SpaceID) + } + if cfg.Dyno != "" { + logger = logger.With("dyno", cfg.Dyno) + } + + return logger +} + +// ReportPanic attempts to report the panic to rollbar via the slog. +func ReportPanic(logger *slog.Logger) { + if p := recover(); p != nil { + s := fmt.Sprint(p) + logger.With("at", "panic").Error(s) + panic(p) + } +} + +func ParseLevel(s string) (slog.Level, error) { + var level slog.Level + var err = level.UnmarshalText([]byte(s)) + return level, err +} diff --git a/cmdutil/v2/svclog/logger_test.go b/cmdutil/v2/svclog/logger_test.go new file mode 100644 index 00000000..893d2396 --- /dev/null +++ b/cmdutil/v2/svclog/logger_test.go @@ -0,0 +1,58 @@ +package svclog + +import ( + "bytes" + "testing" + + "github.com/heroku/x/testing/v2/testlog" +) + +func TestLoggerEmitsAppAndDeployData(t *testing.T) { + buf := bytes.NewBuffer([]byte{}) + cfg := Config{ + AppName: "sushi", + Deploy: "production", + LogLevel: "INFO", + Dyno: "web.1", + WriteTo: buf, + } + logger := NewLogger(cfg) + logger.Info("message") + + testlog.ExpectLogLineFromBuffer(t, buf, "", map[string]interface{}{ + "app": "sushi", + "deploy": "production", + "msg": "message", + "dyno": "web.1", + }) +} + +func TestReportPanic(t *testing.T) { + buf := bytes.NewBuffer([]byte{}) + cfg := Config{ + AppName: "sushi", + Deploy: "production", + LogLevel: "INFO", + Dyno: "web.1", + WriteTo: buf, + } + logger := NewLogger(cfg) + + defer func() { + if p := recover(); p == nil { + t.Fatal("expected ReportPanic to repanic") + } + + testlog.ExpectLogLineFromBuffer(t, buf, "", map[string]interface{}{ + "msg": "\"test message\"", + "at": "panic", + "level": "ERROR", + }) + }() + + func() { + defer ReportPanic(logger) + + panic("test message") + }() +} diff --git a/go.mod b/go.mod index fe637a43..20d3be16 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,7 @@ require ( github.com/golang/protobuf v1.5.3 github.com/gomodule/redigo v1.8.9 github.com/google/gops v0.3.22 - github.com/google/uuid v1.3.1 + github.com/google/uuid v1.4.0 github.com/grpc-ecosystem/go-grpc-middleware v1.2.0 github.com/grpc-ecosystem/grpc-gateway v1.16.0 github.com/heroku/rollrus v0.2.0 @@ -46,12 +46,12 @@ require ( go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v0.42.0 go.opentelemetry.io/otel/metric v1.27.0 golang.org/x/crypto v0.21.0 - golang.org/x/sync v0.3.0 + golang.org/x/sync v0.7.0 golang.org/x/sys v0.20.0 // indirect golang.org/x/time v0.0.0-20181108054448-85acf8d2951c - google.golang.org/grpc v1.59.0 + google.golang.org/grpc v1.61.1 google.golang.org/grpc/examples v0.0.0-20210916203835-567da6b86340 - google.golang.org/protobuf v1.31.0 + google.golang.org/protobuf v1.32.0 gopkg.in/ini.v1 v1.42.0 // indirect gopkg.in/square/go-jose.v2 v2.5.1 ) @@ -69,15 +69,17 @@ require ( github.com/go-logr/logr v1.4.1 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6 // indirect + github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.0 // indirect github.com/konsorten/go-windows-terminal-sequences v1.0.3 // indirect go.opentelemetry.io/otel/trace v1.27.0 // indirect go.opentelemetry.io/proto/otlp v1.0.0 // indirect + golang.org/x/exp v0.0.0-20240716175740-e3f259677ff7 // indirect golang.org/x/net v0.23.0 // indirect - golang.org/x/oauth2 v0.12.0 // indirect + golang.org/x/oauth2 v0.16.0 // indirect google.golang.org/api v0.7.0 // indirect - google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto v0.0.0-20231012201019-e917dd12ba7a // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20231016165738-49dd2c1f3d0b // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b // indirect + google.golang.org/appengine v1.6.8 // indirect + google.golang.org/genproto v0.0.0-20240205150955-31a09d347014 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240125205218-1f4bbc51befe // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240213162025-012b6fc9bca9 // indirect ) diff --git a/go.sum b/go.sum index dca84a48..4359c716 100644 --- a/go.sum +++ b/go.sum @@ -90,6 +90,7 @@ github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QD github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/gomodule/redigo v1.8.9 h1:Sl3u+2BI/kk+VEatbj0scLdrFhjPmbxOc1myhDP41ws= @@ -110,11 +111,15 @@ github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OI github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= +github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/grpc-ecosystem/go-grpc-middleware v1.2.0 h1:0IKlLyQ3Hs9nDaiK5cSHAGmcQEIC8l2Ts1u6x5Dfrqg= github.com/grpc-ecosystem/go-grpc-middleware v1.2.0/go.mod h1:mJzapYve32yjrKlk9GbyCZHuPgZsrbyIbyKhSzOpg6s= +github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0 h1:pRhl55Yx1eC7BZ1N+BBWwnKaMyD8uC+34TLdndZMAKk= +github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0/go.mod h1:XKMd7iuf/RGPSMJ/U4HP0zS2Z9Fh8Ps9a+6X26m/tmI= github.com/grpc-ecosystem/grpc-gateway v1.9.4/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= @@ -203,6 +208,7 @@ github.com/urfave/cli/v2 v2.2.0/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2 github.com/xlab/treeprint v1.1.0/go.mod h1:gj5Gd3gPdKtR1ikdDK6fnFLdmIS0X30kTTuNd/WEJu0= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.1 h1:8dP3SGL7MPB94crU3bEPplMPe83FI4EouesJUeFHv50= @@ -233,9 +239,12 @@ go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20240716175740-e3f259677ff7 h1:wDLEX9a7YQoKdKNQt88rtydkqDxeGaBUTnIYc3iG/mA= +golang.org/x/exp v0.0.0-20240716175740-e3f259677ff7/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= @@ -244,6 +253,7 @@ golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHl golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -261,6 +271,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200505041828-1ed23360d12c/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -269,6 +281,8 @@ golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4Iltr golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.12.0 h1:smVPGxink+n1ZI5pkQa8y6fZT0RW0MgCO5bFpepy4B4= golang.org/x/oauth2 v0.12.0/go.mod h1:A74bZ3aGXgCY0qaIC9Ahg6Lglin4AMAco8cIv9baba4= +golang.org/x/oauth2 v0.16.0 h1:aDkGMBSYxElaoP81NpoUoz2oo2R2wHdZpGToUxfyQrQ= +golang.org/x/oauth2 v0.16.0/go.mod h1:hqZ+0LWXsiVoZpeld6jVt06P3adbS2Uu911W1SsJv2o= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -276,8 +290,11 @@ golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -290,14 +307,22 @@ golang.org/x/sys v0.0.0-20190804053845-51ab0e2deafa/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210816074244-15123e1e1f71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210902050250-f475640dd07b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c h1:fqgJT0MGcGpPgpWU7VRdRjuArfcOvC4AoJmILihzhDg= @@ -315,6 +340,7 @@ golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -327,6 +353,8 @@ google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7 google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= +google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20181221175505-bd9b4fb69e2f/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= @@ -340,10 +368,16 @@ google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEY google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20231012201019-e917dd12ba7a h1:fwgW9j3vHirt4ObdHoYNwuO24BEZjSzbh+zPaNWoiY8= google.golang.org/genproto v0.0.0-20231012201019-e917dd12ba7a/go.mod h1:EMfReVxb80Dq1hhioy0sOsY9jCE46YDgHlJ7fWVUWRE= +google.golang.org/genproto v0.0.0-20240205150955-31a09d347014 h1:g/4bk7P6TPMkAUbUhquq98xey1slwvuVJPosdBqYJlU= +google.golang.org/genproto v0.0.0-20240205150955-31a09d347014/go.mod h1:xEgQu1e4stdSSsxPDK8Azkrk/ECl5HvdPf6nbZrTS5M= google.golang.org/genproto/googleapis/api v0.0.0-20231016165738-49dd2c1f3d0b h1:CIC2YMXmIhYw6evmhPxBKJ4fmLbOFtXQN/GV3XOZR8k= google.golang.org/genproto/googleapis/api v0.0.0-20231016165738-49dd2c1f3d0b/go.mod h1:IBQ646DjkDkvUIsVq/cc03FUFQ9wbZu7yE396YcL870= +google.golang.org/genproto/googleapis/api v0.0.0-20240125205218-1f4bbc51befe h1:0poefMBYvYbs7g5UkjS6HcxBPaTRAmznle9jnxYoAI8= +google.golang.org/genproto/googleapis/api v0.0.0-20240125205218-1f4bbc51befe/go.mod h1:4jWUdICTdgc3Ibxmr8nAJiiLHwQBY0UI0XZcEMaFKaA= google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b h1:ZlWIi1wSK56/8hn4QcBp/j9M7Gt3U/3hZw3mC7vDICo= google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b/go.mod h1:swOH3j0KzcDDgGUWr+SNpyTen5YrXjS3eyPzFYKc6lc= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240213162025-012b6fc9bca9 h1:hZB7eLIaYlW9qXRfCq/qDaPdbeY3757uARz5Vvfv+cY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240213162025-012b6fc9bca9/go.mod h1:YUWgXUFRPfoYK1IHMuxH5K6nPEXSCzIMljnQ59lLRCk= google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.0/go.mod h1:chYK+tFQF0nDUGJgXMSgLCQk3phJEuONr2DCgLDdAQM= @@ -356,6 +390,8 @@ google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTp google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.59.0 h1:Z5Iec2pjwb+LEOqzpB2MR12/eKFhDPhuqW91O+4bwUk= google.golang.org/grpc v1.59.0/go.mod h1:aUPDwccQo6OTjy7Hct4AfBPD1GptF4fyUjIkQ9YtF98= +google.golang.org/grpc v1.61.1 h1:kLAiWrZs7YeDM6MumDe7m3y4aM6wacLzM1Y/wiLP9XY= +google.golang.org/grpc v1.61.1/go.mod h1:VUbo7IFqmF1QtCAstipjG0GIoq49KvMe9+h1jFLBNJs= google.golang.org/grpc/examples v0.0.0-20210916203835-567da6b86340 h1:ZST99LW/5hCiDvAXb3KZoOwPz1xXbDfU4Gp1TlmX5l4= google.golang.org/grpc/examples v0.0.0-20210916203835-567da6b86340/go.mod h1:gID3PKrg7pWKntu9Ss6zTLJ0ttC0X9IHgREOCZwbCVU= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= @@ -372,6 +408,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= +google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/ini.v1 v1.42.0 h1:7N3gPTt50s8GuLortA00n8AqRTk75qOP98+mTPpgzRk= diff --git a/grpc/v2/grpcserver/codes.go b/grpc/v2/grpcserver/codes.go new file mode 100644 index 00000000..93a0d090 --- /dev/null +++ b/grpc/v2/grpcserver/codes.go @@ -0,0 +1,27 @@ +package grpcserver + +import ( + "context" + + "github.com/pkg/errors" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" +) + +// ErrorToCode determines the gRPC error code for an error, accounting for +// context errors and errors wrapped with pkg/errors. +// +// ErrorToCode implements grpc_logging.ErrorToCode. +func ErrorToCode(err error) codes.Code { + err = errors.Cause(err) + + switch err { + case context.Canceled: + return codes.Canceled + case context.DeadlineExceeded: + return codes.DeadlineExceeded + default: + //TODO: SA1019: grpc.Code is deprecated: use status.Code instead. (staticcheck) + return grpc.Code(err) //nolint:staticcheck + } +} diff --git a/grpc/v2/grpcserver/codes_test.go b/grpc/v2/grpcserver/codes_test.go new file mode 100644 index 00000000..e21bb7d0 --- /dev/null +++ b/grpc/v2/grpcserver/codes_test.go @@ -0,0 +1,48 @@ +package grpcserver + +import ( + "context" + "testing" + + "github.com/pkg/errors" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestErrorToCode_Unknown(t *testing.T) { + code := ErrorToCode(errors.New("other")) + if code != codes.Unknown { + t.Fatalf("code = %v, want %v", code, codes.Unknown) + } +} + +func TestErrorToCode_GRPC(t *testing.T) { + err := status.Errorf(codes.NotFound, "not found") + code := ErrorToCode(err) + if code != codes.NotFound { + t.Fatalf("code = %v, want %v", code, codes.NotFound) + } +} + +func TestErrorToCode_Wrapped(t *testing.T) { + err := status.Errorf(codes.NotFound, "not found") + code := ErrorToCode(errors.WithStack(err)) + if code != codes.NotFound { + t.Fatalf("code = %v, want %v", code, codes.NotFound) + } +} + +func TestErrorToCode_Canceled(t *testing.T) { + code := ErrorToCode(context.Canceled) + if code != codes.Canceled { + t.Fatalf("code = %v, want %v", code, codes.Canceled) + } +} + +func TestErrorToCode_DeadlineExceeded(t *testing.T) { + code := ErrorToCode(context.DeadlineExceeded) + if code != codes.DeadlineExceeded { + t.Fatalf("code = %v, want %v", code, codes.DeadlineExceeded) + } +} diff --git a/grpc/v2/grpcserver/inprocess.go b/grpc/v2/grpcserver/inprocess.go new file mode 100644 index 00000000..4d084053 --- /dev/null +++ b/grpc/v2/grpcserver/inprocess.go @@ -0,0 +1,52 @@ +package grpcserver + +import ( + "net" + "time" + + "github.com/hydrogen18/memlistener" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +// Local returns an in-process server for the provided gRPC server. +func Local(s *grpc.Server) *LocalServer { + return &LocalServer{ + ln: memlistener.NewMemoryListener(), + srv: s, + } +} + +// An LocalServer serves a gRPC server from memory. +type LocalServer struct { + ln *memlistener.MemoryListener + srv *grpc.Server +} + +// Run starts the in-process server. +// +// It implements oklog group's runFn. +func (s *LocalServer) Run() error { + return s.srv.Serve(s.ln) +} + +// Stop gracefully stops the gRPC server. +// +// It implements oklog group's interruptFn. +func (s *LocalServer) Stop(err error) { + s.srv.GracefulStop() +} + +// Conn returns a client connection to the in-process server. +func (s *LocalServer) Conn(opts ...grpc.DialOption) *grpc.ClientConn { + defaultOptions := []grpc.DialOption{ + // TODO: SA1019: grpc.WithDialer is deprecated: use WithContextDialer instead (staticcheck) + grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) { //nolint:staticcheck + return s.ln.Dial("mem", "") + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + } + + conn, _ := grpc.Dial("", append(defaultOptions, opts...)...) + return conn +} diff --git a/grpc/v2/grpcserver/options.go b/grpc/v2/grpcserver/options.go new file mode 100644 index 00000000..51df934a --- /dev/null +++ b/grpc/v2/grpcserver/options.go @@ -0,0 +1,248 @@ +package grpcserver + +import ( + "context" + "crypto/tls" + "crypto/x509" + "log/slog" + "os" + "time" + + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" + grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" + grpc_validator "github.com/grpc-ecosystem/go-grpc-middleware/validator" + "go.opencensus.io/plugin/ocgrpc" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + + "github.com/heroku/x/go-kit/metrics" + "github.com/heroku/x/grpc/grpcmetrics" + "github.com/heroku/x/grpc/v2/panichandler" + "github.com/heroku/x/tlsconfig" +) + +const ( + defaultReadHeaderTimeout = 60 * time.Second +) + +var defaultLogOpts = []logging.Option{ + logging.WithCodes(ErrorToCode), +} + +type options struct { + logger *slog.Logger + metricsProvider metrics.Provider + authUnaryInterceptor grpc.UnaryServerInterceptor + authStreamInterceptor grpc.StreamServerInterceptor + highCardUnaryInterceptor grpc.UnaryServerInterceptor + highCardStreamInterceptor grpc.StreamServerInterceptor + readHeaderTimeout time.Duration + + useValidateInterceptor bool + + grpcOptions []grpc.ServerOption +} + +func defaultOptions() options { + return options{ + readHeaderTimeout: defaultReadHeaderTimeout, + } +} + +// ServerOption sets optional fields on the standard gRPC server +type ServerOption func(*options) + +// GRPCOption adds a grpc ServerOption to the server. +func GRPCOption(opt grpc.ServerOption) ServerOption { + return func(o *options) { + o.grpcOptions = append(o.grpcOptions, opt) + } +} + +// Logger provided will be added to the context +func Logger(l *slog.Logger) ServerOption { + return func(o *options) { + o.logger = l + } +} + +// MetricsProvider will have metrics reported to it +func MetricsProvider(provider metrics.Provider) ServerOption { + return func(o *options) { + o.metricsProvider = provider + } +} + +// AuthInterceptors sets interceptors that are intended for +// authentication/authorization in the correct locations in the chain +func AuthInterceptors(unary grpc.UnaryServerInterceptor, stream grpc.StreamServerInterceptor) ServerOption { + return func(o *options) { + o.authUnaryInterceptor = unary + o.authStreamInterceptor = stream + } +} + +// HighCardInterceptors sets interceptors that use +// Attributes/Labels on the instrumentation. +func HighCardInterceptors(unary grpc.UnaryServerInterceptor, stream grpc.StreamServerInterceptor) ServerOption { + return func(o *options) { + o.highCardUnaryInterceptor = unary + o.highCardStreamInterceptor = stream + } +} + +// WithOCGRPCServerHandler sets the grpc server up with provided ServerHandler +// as its StatsHandler +func WithOCGRPCServerHandler(h *ocgrpc.ServerHandler) ServerOption { + return func(o *options) { + o.grpcOptions = append(o.grpcOptions, grpc.StatsHandler(h)) + } +} + +func WithReadHeaderTimeout(d time.Duration) ServerOption { + return func(o *options) { + o.readHeaderTimeout = d + } +} + +// ValidateInterceptor sets interceptors that will validate every +// message that has a receiver of the form `Validate() error` +// +// See github.com/mwitkow/go-proto-validators for details. +func ValidateInterceptor() ServerOption { + return func(o *options) { + o.useValidateInterceptor = true + } +} + +// InterceptorLogger adapts slog logger to interceptor logger. +// This code is simple enough to be copied and not imported. +// See https://github.com/grpc-ecosystem/go-grpc-middleware/blob/62b7de50cda5a5d633f1013bfbe50e0f38db34ef/interceptors/logging/examples/slog/example_test.go#17 +func InterceptorLogger(l *slog.Logger) logging.Logger { + return logging.LoggerFunc(func(ctx context.Context, lvl logging.Level, msg string, fields ...any) { + l.Log(ctx, slog.Level(lvl), msg, fields...) + }) +} + +func (o *options) unaryInterceptors() []grpc.UnaryServerInterceptor { + l := o.logger + if l == nil { + l = slog.New(slog.NewTextHandler(os.Stdout, nil)) + } + + i := []grpc.UnaryServerInterceptor{ + panichandler.LoggingUnaryPanicHandler(l), + grpc_ctxtags.UnaryServerInterceptor(), + UnaryPayloadLoggingTagger, + unaryRequestIDTagger, + unaryPeerNameTagger, + } + + if o.highCardUnaryInterceptor != nil { + i = append(i, o.highCardUnaryInterceptor) + } else if o.metricsProvider != nil { + i = append(i, grpcmetrics.NewUnaryServerInterceptor(o.metricsProvider)) // report metrics on unwrapped errors + } + + i = append(i, + unaryServerErrorUnwrapper, // unwrap after we've logged + logging.UnaryServerInterceptor(InterceptorLogger(l), defaultLogOpts...), + ) + if o.authUnaryInterceptor != nil { + i = append(i, o.authUnaryInterceptor) + } + if o.useValidateInterceptor { + i = append(i, grpc_validator.UnaryServerInterceptor()) + } + + return i +} + +func (o *options) streamInterceptors() []grpc.StreamServerInterceptor { + l := o.logger + if l == nil { + l = slog.New(slog.NewTextHandler(os.Stdout, nil)) + } + + i := []grpc.StreamServerInterceptor{ + panichandler.LoggingStreamPanicHandler(l), + grpc_ctxtags.StreamServerInterceptor(), + streamRequestIDTagger, + streamPeerNameTagger, + } + + if o.highCardStreamInterceptor != nil { + i = append(i, o.highCardStreamInterceptor) + } else if o.metricsProvider != nil { + i = append(i, grpcmetrics.NewStreamServerInterceptor(o.metricsProvider)) // report metrics on unwrapped errors + } + + i = append(i, + streamServerErrorUnwrapper, // unwrap after we've logged + logging.StreamServerInterceptor(InterceptorLogger(l), defaultLogOpts...), + ) + if o.authStreamInterceptor != nil { + i = append(i, o.authStreamInterceptor) + } + if o.useValidateInterceptor { + i = append(i, grpc_validator.StreamServerInterceptor()) + } + + return i +} + +func (o *options) serverOptions() []grpc.ServerOption { + opts := []grpc.ServerOption{ + grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(o.unaryInterceptors()...)), + grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(o.streamInterceptors()...)), + } + opts = append(opts, o.grpcOptions...) + return opts +} + +// TLS returns a ServerOption which adds mutual-TLS to the gRPC server. +func TLS(caCerts [][]byte, serverCert tls.Certificate) (ServerOption, error) { + tlsConfig, err := tlsconfig.NewMutualTLS(caCerts, serverCert) + if err != nil { + return nil, err + } + + return GRPCOption(grpc.Creds(credentials.NewTLS(tlsConfig))), nil +} + +// WithPeerValidator configures the gRPC server to reject calls from peers +// which do not provide a certificate or for which the provided function +// returns false. +func WithPeerValidator(f func(*x509.Certificate) bool) ServerOption { + return func(o *options) { + o.authStreamInterceptor = func(req interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if err := validatePeer(ss.Context(), f); err != nil { + return err + } + return handler(req, ss) + } + o.authUnaryInterceptor = func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) { + if err := validatePeer(ctx, f); err != nil { + return nil, err + } + return handler(ctx, req) + } + } +} + +func validatePeer(ctx context.Context, f func(*x509.Certificate) bool) error { + cert, ok := getPeerCertFromContext(ctx) + if !ok { + // TODO: SA1019: grpc.Errorf is deprecated: use status.Errorf instead. (staticcheck) + return grpc.Errorf(codes.Unauthenticated, "unauthenticated") //nolint:staticcheck + } + + if !f(cert) { + // TODO: SA1019: grpc.Errorf is deprecated: use status.Errorf instead. (staticcheck) + return grpc.Errorf(codes.PermissionDenied, "forbidden") //nolint:staticcheck + } + + return nil +} diff --git a/grpc/v2/grpcserver/payload.go b/grpc/v2/grpcserver/payload.go new file mode 100644 index 00000000..9a733909 --- /dev/null +++ b/grpc/v2/grpcserver/payload.go @@ -0,0 +1,53 @@ +package grpcserver + +import ( + "context" + "fmt" + + grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" + "google.golang.org/grpc" +) + +// UnaryPayloadLoggingTagger annotates ctx with grpc_ctxtags tags for request and +// response payloads. +// +// A loggable request or response implements this interface +// +// type loggable interface { +// LoggingTags() map[string]interface{} +// } +// +// Any request or response implementing this interface will add tags to the +// context for logging in success and error cases. +func UnaryPayloadLoggingTagger(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) { + tag(ctx, "request", req) + + resp, err := handler(ctx, req) + if err == nil { + tag(ctx, "response", resp) + } + + return resp, err +} + +type loggable interface { + LoggingTags() map[string]interface{} +} + +func tag(ctx context.Context, scope string, pb interface{}) { + tags := grpc_ctxtags.Extract(ctx) + extractTags(tags, scope, pb) +} + +func extractTags(tags grpc_ctxtags.Tags, scope string, pb interface{}) { + if lg, ok := pb.(loggable); ok { + for k, v := range lg.LoggingTags() { + name := fmt.Sprintf("%s.%s", scope, k) + if _, ok := v.(loggable); ok { + extractTags(tags, name, v) + } else { + tags.Set(name, v) + } + } + } +} diff --git a/grpc/v2/grpcserver/payload_test.go b/grpc/v2/grpcserver/payload_test.go new file mode 100644 index 00000000..ca50a3d2 --- /dev/null +++ b/grpc/v2/grpcserver/payload_test.go @@ -0,0 +1,79 @@ +package grpcserver + +import ( + "reflect" + "testing" + + grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" +) + +func TestExtractTagsWithoutLoggingTagsCompatibleValue(t *testing.T) { + tags := newTags() + extractTags(tags, "scope", "value") + got := tags.Values() + want := make(map[string]interface{}) + if !reflect.DeepEqual(got, want) { + t.Fatalf("got %#v, want %#v", got, want) + } +} + +func TestExtractTags(t *testing.T) { + tags := newTags() + extractTags(tags, "scope", &value{}) + got := tags.Values() + want := make(map[string]interface{}) + want["scope.value"] = "hello" + if !reflect.DeepEqual(got, want) { + t.Fatalf("got %#v, want %#v", got, want) + } +} + +func TestExtractTagsWithNestedValue(t *testing.T) { + tags := newTags() + extractTags(tags, "scope", &nestedValue{}) + got := tags.Values() + want := make(map[string]interface{}) + want["scope.nested.value"] = "hello" + if !reflect.DeepEqual(got, want) { + t.Fatalf("got %#v, want %#v", got, want) + } +} + +type value struct{} + +func (v *value) LoggingTags() map[string]interface{} { + res := make(map[string]interface{}) + res["value"] = "hello" + return res +} + +type nestedValue struct{} + +func (v *nestedValue) LoggingTags() map[string]interface{} { + res := make(map[string]interface{}) + res["nested"] = &value{} + return res +} + +// testTags mirrors the implementation of grpc_ctxtags.Tags +type testTags struct { + values map[string]interface{} +} + +func (t *testTags) Set(key string, value interface{}) grpc_ctxtags.Tags { + t.values[key] = value + return t +} + +func (t *testTags) Has(key string) bool { + _, ok := t.values[key] + return ok +} + +func (t *testTags) Values() map[string]interface{} { + return t.values +} + +func newTags() grpc_ctxtags.Tags { + return &testTags{values: make(map[string]interface{})} +} diff --git a/grpc/v2/grpcserver/server.go b/grpc/v2/grpcserver/server.go new file mode 100644 index 00000000..ec305ecf --- /dev/null +++ b/grpc/v2/grpcserver/server.go @@ -0,0 +1,194 @@ +package grpcserver + +import ( + "context" + "crypto/tls" + "crypto/x509" + "log/slog" + "net" + "net/http" + "os" + "strconv" + + grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" + "github.com/lstoll/grpce/h2c" + "github.com/pkg/errors" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + healthgrpc "google.golang.org/grpc/health" + healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/peer" + + "github.com/heroku/x/cmdutil" + "github.com/heroku/x/grpc/requestid" +) + +// New configures a gRPC Server with default options and a health server. +func New(opts ...ServerOption) *grpc.Server { + var o options + for _, so := range opts { + so(&o) + } + + srv := grpc.NewServer(o.serverOptions()...) + + healthpb.RegisterHealthServer(srv, healthgrpc.NewServer()) + + return srv +} + +// A Starter registers and starts itself on the provided grpc.Server. +// +// It's expected Start will call the relevant RegisterXXXServer method +// using srv. +type Starter interface { + Start(srv *grpc.Server) error +} + +// RunStandardServer runs a GRPC server with a standard setup including metrics +// (if provider passed), panic handling, a health check service, TLS termination +// with client authentication, and proxy-protocol wrapping. +// +// Deprecated: Use NewStandardServer instead. +func RunStandardServer(logger *slog.Logger, port int, serverCACerts [][]byte, serverCert, serverKey []byte, server Starter, opts ...ServerOption) error { + cert, err := tls.X509KeyPair(serverCert, serverKey) + if err != nil { + return errors.Wrap(err, "creating X509 key pair") + } + + return NewStandardServer(logger, port, serverCACerts, cert, server, opts...).Run() +} + +// NewStandardServer configures a GRPC server with a standard setup including metrics +// (if provider passed), panic handling, a health check service, TLS termination +// with client authentication, and proxy-protocol wrapping. +func NewStandardServer(logger *slog.Logger, port int, serverCACerts [][]byte, serverCert tls.Certificate, server Starter, opts ...ServerOption) cmdutil.Server { + tls, err := TLS(serverCACerts, serverCert) + if err != nil { + logger.Error(err.Error()) + os.Exit(1) + } + + opts = append(opts, tls, Logger(logger.With(slog.String("component", "grpc")))) + grpcsrv := New(opts...) + + if err := server.Start(grpcsrv); err != nil { + logger.Error(err.Error()) + os.Exit(1) + } + + return TCP(logger, grpcsrv, net.JoinHostPort("", strconv.Itoa(port))) +} + +// NewStandardH2C create a set of servers suitable for serving gRPC services +// using H2C (aka client upgrades). This is suitable for serving gRPC services +// via both hermes and dogwood-router. HTTP 1.x traffic will be passed to the +// provided handler. This will return a *grpc.Server configured with our +// standard set of services, and a HTTP server that should be what is served on +// a listener. +func NewStandardH2C(http11 http.Handler, opts ...ServerOption) (*grpc.Server, *http.Server) { + o := defaultOptions() + for _, so := range opts { + so(&o) + } + + gSrv := grpc.NewServer(o.serverOptions()...) + + healthpb.RegisterHealthServer(gSrv, healthgrpc.NewServer()) + + h2cSrv := &h2c.Server{ + HTTP2Handler: gSrv, + NonUpgradeHandler: http11, + } + + hSrv := &http.Server{ + Handler: h2cSrv, + ReadHeaderTimeout: o.readHeaderTimeout, + } + + return gSrv, hSrv +} + +// unaryServerErrorUnwrapper removes errors.Wrap annotations from errors so +// gRPC status codes are correctly returned to interceptors and clients later +// in the chain. +func unaryServerErrorUnwrapper(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) { + res, err := handler(ctx, req) + return res, errors.Cause(err) +} + +// streamServerErrorUnwrapper removes errors.Wrap annotations from errors so +// gRPC status codes are correctly returned to interceptors and clients later +// in the chain. +func streamServerErrorUnwrapper(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + err := handler(srv, ss) + return errors.Cause(err) +} + +// unaryRequestIDTagger sets a grpc_ctxtags request_id tag for logging if the +// context includes a request ID. +func unaryRequestIDTagger(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) { + if id, ok := requestid.FromContext(ctx); ok { + grpc_ctxtags.Extract(ctx).Set("request_id", id) + } + + return handler(ctx, req) +} + +// streamRequestIDTagger sets a grpc_ctxtags request_id tag for logging if the +// context includes a request ID. +func streamRequestIDTagger(req interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if id, ok := requestid.FromContext(ss.Context()); ok { + grpc_ctxtags.Extract(ss.Context()).Set("request_id", id) + } + + return handler(req, ss) +} + +// unaryPeerNameTagger sets a grpc_ctxtags peer name tag for logging if the +// caller provider provides a mutual TLS certificate. +func unaryPeerNameTagger(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) { + peerName := getPeerNameFromContext(ctx) + if peerName != "" { + grpc_ctxtags.Extract(ctx).Set("peer.name", peerName) + } + + return handler(ctx, req) +} + +// streamPeerNameTagger sets a grpc_ctxtags peer name tag for logging if the +// caller provider provides a mutual TLS certificate. +func streamPeerNameTagger(req interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + peerName := getPeerNameFromContext(ss.Context()) + if peerName != "" { + grpc_ctxtags.Extract(ss.Context()).Set("peer.name", peerName) + } + + return handler(req, ss) +} + +func getPeerNameFromContext(ctx context.Context) string { + cert, ok := getPeerCertFromContext(ctx) + if !ok { + return "" + } + return cert.Subject.CommonName +} + +func getPeerCertFromContext(ctx context.Context) (*x509.Certificate, bool) { + p, ok := peer.FromContext(ctx) + if !ok { + return nil, false + } + + tlsAuth, ok := p.AuthInfo.(credentials.TLSInfo) + if !ok { + return nil, false + } + + if len(tlsAuth.State.PeerCertificates) == 0 { + return nil, false + } + + return tlsAuth.State.PeerCertificates[0], true +} diff --git a/grpc/v2/grpcserver/server_test.go b/grpc/v2/grpcserver/server_test.go new file mode 100644 index 00000000..4dfbed4f --- /dev/null +++ b/grpc/v2/grpcserver/server_test.go @@ -0,0 +1,125 @@ +package grpcserver + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "testing" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" + + "github.com/heroku/x/testing/mustcert" +) + +func ExampleLocal() { + srv := New() + localsrv := Local(srv) + + go func() { + if err := localsrv.Run(); err != nil { + panic(err) + } + }() + defer localsrv.Stop(nil) + + c := healthpb.NewHealthClient(localsrv.Conn()) + + //TODO: SA1019: grpc.FailFast is deprecated: use WaitForReady. (staticcheck) + resp, err := c.Check(context.Background(), &healthpb.HealthCheckRequest{}, grpc.FailFast(true)) //nolint:staticcheck + if err != nil { + fmt.Printf("Error = %v", err) + return + } + + fmt.Printf("Status = %v", resp.Status) +} + +func TestGetPeerNameFromContext(t *testing.T) { + t.Run("empty context", func(t *testing.T) { + if name := getPeerNameFromContext(context.Background()); name != "" { + t.Errorf("name = %q want %q", name, "") + } + }) + + t.Run("non-mTLS peer", func(t *testing.T) { + ctx := peer.NewContext(context.Background(), &peer.Peer{ + // AuthInfo is nil if there is no transport security, based on the peer + // package's docs. + AuthInfo: nil, + }) + + if name := getPeerNameFromContext(ctx); name != "" { + t.Errorf("name = %q want %q", name, "") + } + }) + + t.Run("with an mTLS peer", func(t *testing.T) { + clientName := "client" + clientCert := mustcert.Leaf(clientName, nil) + + ctx := peer.NewContext(context.Background(), &peer.Peer{ + AuthInfo: credentials.TLSInfo{ + State: tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{ + clientCert.TLS().Leaf, + }, + }, + }, + }) + + if name := getPeerNameFromContext(ctx); name != clientName { + t.Errorf("name = %q want %q", name, clientName) + } + }) +} + +func TestValidatePeer(t *testing.T) { + clientName := "client" + clientCert := mustcert.Leaf(clientName, nil) + validPeer := &peer.Peer{ + AuthInfo: credentials.TLSInfo{ + State: tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{ + clientCert.TLS().Leaf, + }, + }, + }, + } + + t.Run("non-mTLS peer", func(t *testing.T) { + ctx := peer.NewContext(context.Background(), &peer.Peer{ + // AuthInfo is nil if there is no transport security, based on the peer + // package's docs. + AuthInfo: nil, + }) + + err := validatePeer(ctx, func(*x509.Certificate) bool { return false }) + if status.Code(err) != codes.Unauthenticated { + t.Fatalf("err = %+v want %v", err, codes.Unauthenticated) + } + }) + + t.Run("valid peer rejected by validator", func(t *testing.T) { + ctx := peer.NewContext(context.Background(), validPeer) + + err := validatePeer(ctx, func(*x509.Certificate) bool { return false }) + if status.Code(err) != codes.PermissionDenied { + t.Fatalf("err = %+v want %v", err, codes.PermissionDenied) + } + }) + + t.Run("valid peer accepted by validator", func(t *testing.T) { + ctx := peer.NewContext(context.Background(), validPeer) + + err := validatePeer(ctx, func(*x509.Certificate) bool { return true }) + if err != nil { + t.Fatalf("err = %+v want nil", err) + } + }) +} diff --git a/grpc/v2/grpcserver/tcp.go b/grpc/v2/grpcserver/tcp.go new file mode 100644 index 00000000..d3058e29 --- /dev/null +++ b/grpc/v2/grpcserver/tcp.go @@ -0,0 +1,53 @@ +package grpcserver + +import ( + "log/slog" + "net" + + proxyproto "github.com/armon/go-proxyproto" + "google.golang.org/grpc" +) + +// TCP returns a TCP server for the provided gRPC server. +// +// The server transparently handles proxy protocol. +func TCP(l *slog.Logger, s *grpc.Server, addr string) *TCPServer { + return &TCPServer{ + logger: l, + srv: s, + addr: addr, + } +} + +// A TCPServer serves a gRPC server over TCP with proxy-protocol support. +type TCPServer struct { + logger *slog.Logger + srv *grpc.Server + addr string +} + +// Run binds to the configured address and serves the gRPC server. +// +// It implements oklog group's runFn. +func (s *TCPServer) Run() error { + ln, err := net.Listen("tcp", s.addr) + if err != nil { + return err + } + proxyprotoLn := &proxyproto.Listener{Listener: ln} + + s.logger.With( + slog.String("at", "binding"), + slog.String("service", "grpc-tcp"), + slog.String("addr", ln.Addr().String()), + ).Info("") + + return s.srv.Serve(proxyprotoLn) +} + +// Stop gracefully stops the gRPC server. +// +// It implements oklog group's interruptFn. +func (s *TCPServer) Stop(err error) { + s.srv.GracefulStop() +} diff --git a/grpc/v2/panichandler/panichandler.go b/grpc/v2/panichandler/panichandler.go new file mode 100644 index 00000000..01a7adda --- /dev/null +++ b/grpc/v2/panichandler/panichandler.go @@ -0,0 +1,49 @@ +package panichandler + +import ( + "context" + "log/slog" + + "github.com/pkg/errors" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" +) + +// LoggingUnaryPanicHandler returns a server interceptor which recovers +// panics, logs them as errors with logger, and returns a gRPC internal +// error to clients. +func LoggingUnaryPanicHandler(logger *slog.Logger) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + defer handleCrash(func(r interface{}) { + werr := errors.Errorf("grpc unary server panic: %v", r) + logger.Error("grpc unary server panic", slog.String("error", werr.Error())) + err = toPanicError(werr) + }) + return handler(ctx, req) + } +} + +// LoggingStreamPanicHandler returns a stream server interceptor which +// recovers panics, logs them as errors with logger, and returns a +// gRPC internal error to clients. +func LoggingStreamPanicHandler(logger *slog.Logger) grpc.StreamServerInterceptor { + return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { + defer handleCrash(func(r interface{}) { + werr := errors.Errorf("grpc stream server panic: %v", r) + logger.Error("grpc stream server panic", slog.String("error", werr.Error())) + err = toPanicError(werr) + }) + return handler(srv, stream) + } +} + +func handleCrash(handler func(interface{})) { + if r := recover(); r != nil { + handler(r) + } +} + +func toPanicError(r interface{}) error { + //TODO: SA1019: grpc.Errorf is deprecated: use status.Errorf instead. (staticcheck) + return grpc.Errorf(codes.Internal, "panic: %v", r) //nolint:staticcheck +} diff --git a/grpc/v2/panichandler/panichandler_test.go b/grpc/v2/panichandler/panichandler_test.go new file mode 100644 index 00000000..5c2826bf --- /dev/null +++ b/grpc/v2/panichandler/panichandler_test.go @@ -0,0 +1,139 @@ +package panichandler + +import ( + "context" + "errors" + "testing" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/heroku/x/testing/v2/testlog" +) + +func TestLoggingUnaryPanicHandler_NoPanic(t *testing.T) { + l, hook := testlog.New() + + var ( + uhCalled bool + res = 1 + testErr = errors.New("test error") + ) + + uh := func(ctx context.Context, req interface{}) (interface{}, error) { + uhCalled = true + return res, testErr + } + + ph := LoggingUnaryPanicHandler(l) + gres, gerr := ph(context.Background(), nil, nil, uh) + + if !uhCalled { + t.Fatal("uh not called") + } + + if gres != res { + t.Fatalf("got res %+v, want %+v", gres, res) + } + + if gerr != testErr { + t.Fatalf("got err %+v, want %+v", gerr, testErr) + } + + if !hook.IsEmpty() { + t.Fatal("got log lines wanted nothing logged") + } +} + +func TestLoggingUnaryPanicHandler_Panic(t *testing.T) { + l, hook := testlog.New() + + var ( + uhCalled bool + res = 1 + testErr = errors.New("test error") + ) + + uh := func(ctx context.Context, req interface{}) (interface{}, error) { + uhCalled = true + if uhCalled { + panic("BOOM") + } + return res, testErr + } + + ph := LoggingUnaryPanicHandler(l) + _, gerr := ph(context.Background(), nil, nil, uh) + + if !uhCalled { + t.Fatal("unary handler not called") + } + + st, ok := status.FromError(gerr) + if !ok || st.Code() != codes.Internal { + t.Fatalf("Got %+v want Internal grpc error", gerr) + } + + hook.ExpectAllContain(t, "grpc unary server panic") +} + +func TestLoggingStreamPanicHandler_NoPanic(t *testing.T) { + l, hook := testlog.New() + + var ( + shCalled bool + testErr = errors.New("test error") + ) + + sh := func(srv interface{}, stream grpc.ServerStream) error { + shCalled = true + return testErr + } + + ph := LoggingStreamPanicHandler(l) + gerr := ph(context.Background(), nil, nil, sh) + + if !shCalled { + t.Fatal("stream handler not called") + } + + if gerr != testErr { + t.Fatalf("got err %+v, want %+v", gerr, testErr) + } + + if !hook.IsEmpty() { + t.Fatal("got log lines wanted nothing logged") + } +} + +func TestLoggingStreamPanicHandler_Panic(t *testing.T) { + l, hook := testlog.New() + + var ( + shCalled bool + testErr = errors.New("test error") + ) + + sh := func(srv interface{}, stream grpc.ServerStream) error { + shCalled = true + if shCalled { + panic("BOOM") + } + return testErr + } + + ph := LoggingStreamPanicHandler(l) + gerr := ph(context.Background(), nil, nil, sh) + + if !shCalled { + t.Fatal("stream handler not called") + } + + st, ok := status.FromError(gerr) + if !ok || st.Code() != codes.Internal { + t.Fatalf("Got %+v want Internal grpc error", gerr) + } + + hook.ExpectAllContain(t, "grpc stream server panic") +} diff --git a/testing/v2/testlog/testlog.go b/testing/v2/testlog/testlog.go new file mode 100644 index 00000000..2e4126fe --- /dev/null +++ b/testing/v2/testlog/testlog.go @@ -0,0 +1,76 @@ +package testlog + +import ( + "bufio" + "bytes" + "fmt" + "log/slog" + "strings" + "testing" +) + +// Hook is used for validating logs. +type Hook struct { + buf *bytes.Buffer +} + +// New returns a new logger and hook suitable for testing. +func New() (*slog.Logger, *Hook) { + hook := &Hook{ + buf: bytes.NewBuffer([]byte{}), + } + + hopts := &slog.HandlerOptions{ + Level: slog.LevelInfo, + } + logger := slog.New(slog.NewTextHandler(hook.buf, hopts)).With( + slog.String("app", "test-app"), + slog.String("deploy", "local"), + slog.String("dyno", "web.1"), + ) + + return logger, hook +} + +// IsEmpty returns true if there no logs have been written to the hook. +func (hook *Hook) IsEmpty() bool { + return hook.buf.Len() == 0 +} + +// ExpectAllContain validates that all log lines contain this substring. +func (hook *Hook) ExpectAllContain(t *testing.T, msg string) { + scanner := bufio.NewScanner(hook.buf) + for scanner.Scan() { + if s := scanner.Text(); !strings.Contains(s, msg) { + t.Errorf("expected log line '%s' to contain '%s'", s, msg) + } + } + + if err := scanner.Err(); err != nil { + t.Fatal(err) + } +} + +// ExpectLogLine uses the hook to validate that +// the next log line contains the passed message and set of key-values in the passed map. +func (hook *Hook) ExpectLogLine(t *testing.T, msg string, m map[string]interface{}) { + ExpectLogLineFromBuffer(t, hook.buf, msg, m) +} + +// ExpectLogLineFromBuffer is the same as hook.ExpectLogLine but instead validates lines from a buffer. +func ExpectLogLineFromBuffer(t *testing.T, b *bytes.Buffer, msg string, m map[string]interface{}) { + line, err := b.ReadString('\n') + if err != nil { + t.Fatal(err) + } + + if !strings.Contains(line, msg) { + t.Errorf("expected log line to contain message: %s", msg) + } + + for k, v := range m { + if !strings.Contains(line, fmt.Sprintf("%s=%s", k, v)) { + t.Errorf("expected log line to contain %s=%s", k, v) + } + } +} diff --git a/v2/healthcheck/tcpserver.go b/v2/healthcheck/tcpserver.go new file mode 100644 index 00000000..846e081c --- /dev/null +++ b/v2/healthcheck/tcpserver.go @@ -0,0 +1,98 @@ +package healthcheck + +import ( + "fmt" + "log/slog" + "net" + "time" + + "github.com/go-kit/kit/metrics" + + hmetrics "github.com/heroku/x/go-kit/metrics" +) + +// TCPServer answers healthcheck requests from TCP routers, such as an ELB. +type TCPServer struct { + logger *slog.Logger + addr string + ln net.Listener + counter metrics.Counter +} + +// NewTCPServer initializes a new health-check server. +func NewTCPServer(logger *slog.Logger, provider hmetrics.Provider, addr string) *TCPServer { + return &TCPServer{ + logger: logger, + counter: provider.NewCounter("health"), + addr: addr, + } +} + +// Run listens on the configured address and responds to healthcheck requests +// from TCP routers, such as an ELB. +func (s *TCPServer) Run() error { + if err := s.start(); err != nil { + return err + } + + return s.serve() +} + +// Stop shuts down the TCPServer if it was already started. +// +// Stop implements the kit.Server interface. +func (s *TCPServer) Stop(error) { + if s.ln != nil { + s.ln.Close() + } +} + +func (s *TCPServer) start() error { + s.logger.With( + slog.String("at", "bind"), + slog.String("addr", s.addr), + ).Info("") + + ln, err := net.Listen("tcp", s.addr) + if err != nil { + return err + } + + s.ln = ln + return nil +} + +func (s *TCPServer) serve() error { + const retryDelay = 50 * time.Millisecond + + for { + conn, err := s.ln.Accept() + if err != nil { + if e, ok := err.(net.Error); ok && e.Timeout() { + s.logger.With( + slog.String("at", "accept"), + slog.String("error", err.Error()), + ).Error(fmt.Sprintf("retrying in %s", retryDelay)) + + time.Sleep(retryDelay) + continue + } + + return err + } + + go s.serveConn(conn) + } +} + +func (s *TCPServer) serveConn(conn net.Conn) { + defer conn.Close() + + s.counter.Add(1) + + if _, err := conn.Write([]byte("OK\n")); err != nil { + s.logger.With( + slog.String("error", err.Error()), + ).Error("") + } +} diff --git a/v2/healthcheck/tcpserver_test.go b/v2/healthcheck/tcpserver_test.go new file mode 100644 index 00000000..1dc0f258 --- /dev/null +++ b/v2/healthcheck/tcpserver_test.go @@ -0,0 +1,54 @@ +package healthcheck + +import ( + "io" + "net" + "testing" + "time" + + "github.com/heroku/x/go-kit/metrics/testmetrics" + "github.com/heroku/x/testing/v2/testlog" +) + +func TestTCPServer(t *testing.T) { + logger, _ := testlog.New() + provider := testmetrics.NewProvider(t) + server := NewTCPServer(logger, provider, "127.0.0.1:0") + + if err := server.start(); err != nil { + t.Fatal("unexpected error", err) + } + + done := make(chan struct{}) + go func() { + defer close(done) + if err := server.serve(); err == nil { + panic("expected error, but got nil") // accept error + } + }() + + conn, err := net.DialTimeout("tcp", server.ln.Addr().String(), time.Second) + if err != nil { + t.Fatalf("unable to dial server: %s", err) + } + defer conn.Close() + + if err := conn.SetDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatal("unexpected error", err) + } + + data, err := io.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + + if got, want := string(data), "OK\n"; got != want { + t.Fatalf("response was %q, want %q", got, want) + } + + // Assert server shuts down after stopping + server.Stop(nil) + <-done + + provider.CheckCounter("health", 1) +}