diff --git a/cmd/remote-storage/app/server.go b/cmd/remote-storage/app/server.go index 3cd1fdd7384..6a3cfb61c6e 100644 --- a/cmd/remote-storage/app/server.go +++ b/cmd/remote-storage/app/server.go @@ -15,6 +15,7 @@ package app import ( + "fmt" "net" "go.uber.org/zap" @@ -104,7 +105,7 @@ func createGRPCServer(opts *Options, tm *tenancy.TenancyManager, handler *shared if opts.TLSGRPC.Enabled { tlsCfg, err := opts.TLSGRPC.Config(logger) if err != nil { - return nil, err + return nil, fmt.Errorf("invalid TLS config: %w", err) } creds := credentials.NewTLS(tlsCfg) grpcOpts = append(grpcOpts, grpc.Creds(creds)) @@ -138,11 +139,10 @@ func (s *Server) Start() error { if err := s.initListener(); err != nil { return err } + s.logger.Info("Starting GRPC server", zap.Stringer("addr", s.grpcConn.Addr())) go func() { - s.logger.Info("Starting GRPC server", zap.String("addr", s.opts.GRPCHostPort)) - if err := s.grpcServer.Serve(s.grpcConn); err != nil { - s.logger.Error("Could not start GRPC server", zap.Error(err)) + s.logger.Error("GRPC server exited", zap.Error(err)) } s.unavailableChannel <- healthcheck.Unavailable }() diff --git a/cmd/remote-storage/app/server_test.go b/cmd/remote-storage/app/server_test.go new file mode 100644 index 00000000000..72a7e8ce0ed --- /dev/null +++ b/cmd/remote-storage/app/server_test.go @@ -0,0 +1,580 @@ +// Copyright (c) 2022 The Jaeger Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package app + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zaptest/observer" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + + "github.com/jaegertracing/jaeger/cmd/flags" + "github.com/jaegertracing/jaeger/internal/grpctest" + "github.com/jaegertracing/jaeger/pkg/config/tlscfg" + "github.com/jaegertracing/jaeger/pkg/healthcheck" + "github.com/jaegertracing/jaeger/pkg/tenancy" + "github.com/jaegertracing/jaeger/ports" + "github.com/jaegertracing/jaeger/proto-gen/storage_v1" + depStoreMocks "github.com/jaegertracing/jaeger/storage/dependencystore/mocks" + factoryMocks "github.com/jaegertracing/jaeger/storage/mocks" + spanStoreMocks "github.com/jaegertracing/jaeger/storage/spanstore/mocks" +) + +var testCertKeyLocation = "../../../pkg/config/tlscfg/testdata" + +func TestNewServer_CreateStorageErrors(t *testing.T) { + factory := new(factoryMocks.Factory) + factory.On("CreateSpanReader").Return(nil, errors.New("no reader")).Once() + factory.On("CreateSpanReader").Return(nil, nil) + factory.On("CreateSpanWriter").Return(nil, errors.New("no writer")).Once() + factory.On("CreateSpanWriter").Return(nil, nil) + factory.On("CreateDependencyReader").Return(nil, errors.New("no deps")).Once() + factory.On("CreateDependencyReader").Return(nil, nil) + + f := func() (*Server, error) { + return NewServer( + &Options{GRPCHostPort: ":0"}, + tenancy.NewTenancyManager(&tenancy.Options{}), + factory, + zap.NewNop(), + ) + } + _, err := f() + require.Error(t, err) + assert.Contains(t, err.Error(), "no reader") + + _, err = f() + require.Error(t, err) + assert.Contains(t, err.Error(), "no writer") + + _, err = f() + require.Error(t, err) + assert.Contains(t, err.Error(), "no deps") + + s, err := f() + require.NoError(t, err) + err = s.Start() + require.NoError(t, err) + validateGRPCServer(t, s.grpcConn.Addr().String(), s.grpcServer) + + s.grpcConn.Close() // causes logged error + <-s.HealthCheckStatus() +} + +func TestServerStart_BadPortErrors(t *testing.T) { + srv := &Server{ + opts: &Options{ + GRPCHostPort: ":-1", + }, + } + assert.Error(t, srv.Start()) +} + +type storageMocks struct { + factory *factoryMocks.Factory + reader *spanStoreMocks.Reader + writer *spanStoreMocks.Writer + depReader *depStoreMocks.Reader +} + +func newStorageMocks() *storageMocks { + reader := new(spanStoreMocks.Reader) + writer := new(spanStoreMocks.Writer) + depReader := new(depStoreMocks.Reader) + + factory := new(factoryMocks.Factory) + factory.On("CreateSpanReader").Return(reader, nil) + factory.On("CreateSpanWriter").Return(writer, nil) + factory.On("CreateDependencyReader").Return(depReader, nil) + + return &storageMocks{ + factory: factory, + reader: reader, + writer: writer, + depReader: depReader, + } +} + +func TestNewServer_TLSConfigError(t *testing.T) { + tlsCfg := tlscfg.Options{ + Enabled: true, + CertPath: "invalid/path", + KeyPath: "invalid/path", + ClientCAPath: "invalid/path", + } + storageMocks := newStorageMocks() + _, err := NewServer( + &Options{GRPCHostPort: ":8081", TLSGRPC: tlsCfg}, + tenancy.NewTenancyManager(&tenancy.Options{}), + storageMocks.factory, + zap.NewNop(), + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid TLS config") +} + +func TestCreateGRPCHandler(t *testing.T) { + storageMocks := newStorageMocks() + h, err := createGRPCHandler(storageMocks.factory, zap.NewNop()) + require.NoError(t, err) + + storageMocks.writer.On("WriteSpan", mock.Anything, mock.Anything).Return(errors.New("writer error")) + _, err = h.WriteSpan(context.Background(), &storage_v1.WriteSpanRequest{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "writer error") + + storageMocks.depReader.On("GetDependencies", mock.Anything, mock.Anything).Return(nil, errors.New("deps error")) + _, err = h.GetDependencies(context.Background(), &storage_v1.GetDependenciesRequest{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "deps error") + + err = h.GetArchiveTrace(nil, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "not implemented") + + _, err = h.WriteArchiveSpan(context.Background(), nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "not implemented") + + err = h.WriteSpanStream(nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "not implemented") +} + +var testCases = []struct { + name string + TLS tlscfg.Options + clientTLS tlscfg.Options + expectError bool + expectClientError bool + expectServerFail bool +}{ + { + // this is a cross test for the "dedicated ports" use case without TLS + name: "should pass with insecure connection", + TLS: tlscfg.Options{ + Enabled: false, + }, + clientTLS: tlscfg.Options{ + Enabled: false, + }, + expectError: false, + expectClientError: false, + expectServerFail: false, + }, + { + name: "should fail with TLS client to untrusted TLS server", + TLS: tlscfg.Options{ + Enabled: true, + CertPath: testCertKeyLocation + "/example-server-cert.pem", + KeyPath: testCertKeyLocation + "/example-server-key.pem", + }, + clientTLS: tlscfg.Options{ + Enabled: true, + ServerName: "example.com", + }, + expectError: true, + expectClientError: true, + expectServerFail: false, + }, + { + name: "should fail with TLS client to trusted TLS server with incorrect hostname", + TLS: tlscfg.Options{ + Enabled: true, + CertPath: testCertKeyLocation + "/example-server-cert.pem", + KeyPath: testCertKeyLocation + "/example-server-key.pem", + }, + clientTLS: tlscfg.Options{ + Enabled: true, + CAPath: testCertKeyLocation + "/example-CA-cert.pem", + ServerName: "nonEmpty", + }, + expectError: true, + expectClientError: true, + expectServerFail: false, + }, + { + name: "should pass with TLS client to trusted TLS server with correct hostname", + TLS: tlscfg.Options{ + Enabled: true, + CertPath: testCertKeyLocation + "/example-server-cert.pem", + KeyPath: testCertKeyLocation + "/example-server-key.pem", + }, + clientTLS: tlscfg.Options{ + Enabled: true, + CAPath: testCertKeyLocation + "/example-CA-cert.pem", + ServerName: "example.com", + }, + expectError: false, + expectClientError: false, + expectServerFail: false, + }, + { + name: "should fail with TLS client without cert to trusted TLS server requiring cert", + TLS: tlscfg.Options{ + Enabled: true, + CertPath: testCertKeyLocation + "/example-server-cert.pem", + KeyPath: testCertKeyLocation + "/example-server-key.pem", + ClientCAPath: testCertKeyLocation + "/example-CA-cert.pem", + }, + clientTLS: tlscfg.Options{ + Enabled: true, + CAPath: testCertKeyLocation + "/example-CA-cert.pem", + ServerName: "example.com", + }, + expectError: false, + expectServerFail: false, + expectClientError: true, + }, + { + name: "should pass with TLS client with cert to trusted TLS server requiring cert", + TLS: tlscfg.Options{ + Enabled: true, + CertPath: testCertKeyLocation + "/example-server-cert.pem", + KeyPath: testCertKeyLocation + "/example-server-key.pem", + ClientCAPath: testCertKeyLocation + "/example-CA-cert.pem", + }, + clientTLS: tlscfg.Options{ + Enabled: true, + CAPath: testCertKeyLocation + "/example-CA-cert.pem", + ServerName: "example.com", + CertPath: testCertKeyLocation + "/example-client-cert.pem", + KeyPath: testCertKeyLocation + "/example-client-key.pem", + }, + expectError: false, + expectServerFail: false, + expectClientError: false, + }, + { + name: "should fail with TLS client without cert to trusted TLS server requiring cert from a different CA", + TLS: tlscfg.Options{ + Enabled: true, + CertPath: testCertKeyLocation + "/example-server-cert.pem", + KeyPath: testCertKeyLocation + "/example-server-key.pem", + ClientCAPath: testCertKeyLocation + "/wrong-CA-cert.pem", // NB: wrong CA + }, + clientTLS: tlscfg.Options{ + Enabled: true, + CAPath: testCertKeyLocation + "/example-CA-cert.pem", + ServerName: "example.com", + CertPath: testCertKeyLocation + "/example-client-cert.pem", + KeyPath: testCertKeyLocation + "/example-client-key.pem", + }, + expectError: false, + expectServerFail: false, + expectClientError: true, + }, +} + +type grpcClient struct { + storage_v1.SpanReaderPluginClient + + conn *grpc.ClientConn +} + +func newGRPCClient(t *testing.T, addr string, creds credentials.TransportCredentials, tm *tenancy.TenancyManager) *grpcClient { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + dialOpts := []grpc.DialOption{grpc.WithUnaryInterceptor(tenancy.NewClientUnaryInterceptor(tm))} + if creds != nil { + dialOpts = append(dialOpts, grpc.WithTransportCredentials(creds)) + } else { + dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + + conn, err := grpc.DialContext(ctx, addr, dialOpts...) + require.NoError(t, err) + + return &grpcClient{ + SpanReaderPluginClient: storage_v1.NewSpanReaderPluginClient(conn), + conn: conn, + } +} + +func TestServerGRPCTLS(t *testing.T) { + testlen := len(testCases) + + tests := make([]struct { + name string + TLS tlscfg.Options + clientTLS tlscfg.Options + expectError bool + expectClientError bool + expectServerFail bool + }, testlen) + copy(tests, testCases) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + serverOptions := &Options{ + GRPCHostPort: ":0", + TLSGRPC: test.TLS, + } + flagsSvc := flags.NewService(ports.QueryAdminHTTP) + flagsSvc.Logger = zap.NewNop() + + storageMocks := newStorageMocks() + expectedServices := []string{"test"} + storageMocks.reader.On("GetServices", mock.AnythingOfType("*context.valueCtx")).Return(expectedServices, nil) + + tm := tenancy.NewTenancyManager(&tenancy.Options{Enabled: true}) + server, err := NewServer( + serverOptions, + tm, + storageMocks.factory, + flagsSvc.Logger, + ) + assert.Nil(t, err) + assert.NoError(t, server.Start()) + + var wg sync.WaitGroup + wg.Add(1) + once := sync.Once{} + + go func() { + for s := range server.HealthCheckStatus() { + flagsSvc.HC().Set(s) + if s == healthcheck.Unavailable { + once.Do(func() { + wg.Done() + }) + } + } + }() + + var clientError error + var client *grpcClient + + if serverOptions.TLSGRPC.Enabled { + clientTLSCfg, err0 := test.clientTLS.Config(zap.NewNop()) + require.NoError(t, err0) + creds := credentials.NewTLS(clientTLSCfg) + client = newGRPCClient(t, server.grpcConn.Addr().String(), creds, tm) + } else { + client = newGRPCClient(t, server.grpcConn.Addr().String(), nil, tm) + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + ctx = tenancy.WithTenant(ctx, "foo") + res, clientError := client.GetServices(ctx, &storage_v1.GetServicesRequest{}) + + if test.expectClientError { + require.Error(t, clientError) + } else { + require.NoError(t, clientError) + assert.Equal(t, expectedServices, res.Services) + } + require.Nil(t, client.conn.Close()) + server.Close() + wg.Wait() + assert.Equal(t, healthcheck.Unavailable, flagsSvc.HC().Get()) + }) + } +} + +// func TestServerBadHostPort(t *testing.T) { +// _, err := NewServer(zap.NewNop(), &querysvc.QueryService{}, nil, +// &QueryOptions{HTTPHostPort: "8080", GRPCHostPort: "127.0.0.1:8081", BearerTokenPropagation: true}, +// tenancy.NewTenancyManager(&tenancy.Options{}), +// opentracing.NoopTracer{}) + +// assert.NotNil(t, err) +// _, err = NewServer(zap.NewNop(), &querysvc.QueryService{}, nil, +// &QueryOptions{HTTPHostPort: "127.0.0.1:8081", GRPCHostPort: "9123", BearerTokenPropagation: true}, +// tenancy.NewTenancyManager(&tenancy.Options{}), +// opentracing.NoopTracer{}) + +// assert.NotNil(t, err) +// } + +// func TestServerInUseHostPort(t *testing.T) { +// const availableHostPort = "127.0.0.1:0" +// conn, err := net.Listen("tcp", availableHostPort) +// require.NoError(t, err) +// defer func() { require.NoError(t, conn.Close()) }() + +// testCases := []struct { +// name string +// httpHostPort string +// grpcHostPort string +// }{ +// {"HTTP host port clash", conn.Addr().String(), availableHostPort}, +// {"GRPC host port clash", availableHostPort, conn.Addr().String()}, +// } +// for _, tc := range testCases { +// t.Run(tc.name, func(t *testing.T) { +// server, err := NewServer( +// zap.NewNop(), +// &querysvc.QueryService{}, +// nil, +// &QueryOptions{ +// HTTPHostPort: tc.httpHostPort, +// GRPCHostPort: tc.grpcHostPort, +// BearerTokenPropagation: true, +// }, +// tenancy.NewTenancyManager(&tenancy.Options{}), +// opentracing.NoopTracer{}, +// ) +// assert.NoError(t, err) + +// err = server.Start() +// assert.Error(t, err) + +// if server.grpcConn != nil { +// server.grpcConn.Close() +// } +// if server.httpConn != nil { +// server.httpConn.Close() +// } +// }) +// } +// } + +// func TestServerSinglePort(t *testing.T) { +// flagsSvc := flags.NewService(ports.QueryAdminHTTP) +// flagsSvc.Logger = zap.NewNop() +// hostPort := ports.GetAddressFromCLIOptions(ports.QueryHTTP, "") +// spanReader := &spanstoremocks.Reader{} +// dependencyReader := &depsmocks.Reader{} +// expectedServices := []string{"test"} +// spanReader.On("GetServices", mock.AnythingOfType("*context.valueCtx")).Return(expectedServices, nil) + +// querySvc := querysvc.NewQueryService(spanReader, dependencyReader, querysvc.QueryServiceOptions{}) +// server, err := NewServer(flagsSvc.Logger, querySvc, nil, +// &QueryOptions{GRPCHostPort: hostPort, HTTPHostPort: hostPort, BearerTokenPropagation: true}, +// tenancy.NewTenancyManager(&tenancy.Options{}), +// opentracing.NoopTracer{}) +// assert.Nil(t, err) +// assert.NoError(t, server.Start()) + +// var wg sync.WaitGroup +// wg.Add(1) +// once := sync.Once{} + +// go func() { +// for s := range server.HealthCheckStatus() { +// flagsSvc.HC().Set(s) +// if s == healthcheck.Unavailable { +// once.Do(func() { +// wg.Done() +// }) +// } + +// } +// wg.Done() +// }() + +// client := newGRPCClient(t, hostPort) +// defer client.conn.Close() + +// ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) +// defer cancel() + +// res, err := client.GetServices(ctx, &api_v2.GetServicesRequest{}) +// assert.NoError(t, err) +// assert.Equal(t, expectedServices, res.Services) + +// server.Close() +// wg.Wait() +// assert.Equal(t, healthcheck.Unavailable, flagsSvc.HC().Get()) +// } + +// func TestServerGracefulExit(t *testing.T) { +// flagsSvc := flags.NewService(ports.QueryAdminHTTP) + +// zapCore, logs := observer.New(zap.ErrorLevel) +// assert.Equal(t, 0, logs.Len(), "Expected initial ObservedLogs to have zero length.") + +// flagsSvc.Logger = zap.New(zapCore) +// hostPort := ports.PortToHostPort(ports.QueryAdminHTTP) + +// querySvc := &querysvc.QueryService{} +// tracer := opentracing.NoopTracer{} + +// server, err := NewServer(flagsSvc.Logger, querySvc, nil, &QueryOptions{GRPCHostPort: hostPort, HTTPHostPort: hostPort}, +// tenancy.NewTenancyManager(&tenancy.Options{}), tracer) +// assert.Nil(t, err) +// assert.NoError(t, server.Start()) +// go func() { +// for s := range server.HealthCheckStatus() { +// flagsSvc.HC().Set(s) +// } +// }() + +// // Wait for servers to come up before we can call .Close() +// // TODO Find a way to wait only as long as necessary. Unconditional sleep slows down the tests. +// time.Sleep(1 * time.Second) +// server.Close() + +// for _, logEntry := range logs.All() { +// assert.True(t, logEntry.Level != zap.ErrorLevel, +// "Error log found on server exit: %v", logEntry) +// } +// } + +func TestServerHandlesPortZero(t *testing.T) { + flagsSvc := flags.NewService(ports.QueryAdminHTTP) + zapCore, logs := observer.New(zap.InfoLevel) + flagsSvc.Logger = zap.New(zapCore) + storageMocks := newStorageMocks() + + server, err := NewServer( + &Options{GRPCHostPort: ":0"}, + tenancy.NewTenancyManager(&tenancy.Options{}), + storageMocks.factory, + flagsSvc.Logger, + ) + require.Nil(t, err) + require.NoError(t, server.Start()) + defer server.Close() + + const line = "Starting GRPC server" + message := logs.FilterMessage(line) + require.Equal(t, 1, message.Len(), "Expected '%s' log message, actual logs: %+v", line, logs) + + onlyEntry := message.All()[0] + hostPort := onlyEntry.ContextMap()["addr"].(string) + validateGRPCServer(t, hostPort, server.grpcServer) +} + +func validateGRPCServer(t *testing.T, hostPort string, server *grpc.Server) { + grpctest.ReflectionServiceValidator{ + HostPort: hostPort, + Server: server, + ExpectedServices: []string{ + "jaeger.storage.v1.SpanReaderPlugin", + "jaeger.storage.v1.SpanWriterPlugin", + "jaeger.storage.v1.DependenciesReaderPlugin", + "jaeger.storage.v1.PluginCapabilities", + "jaeger.storage.v1.ArchiveSpanReaderPlugin", + "jaeger.storage.v1.ArchiveSpanWriterPlugin", + "jaeger.storage.v1.StreamingSpanWriterPlugin", + // "grpc.health.v1.Health", + }, + }.Execute(t) +} diff --git a/pkg/tenancy/grpc.go b/pkg/tenancy/grpc.go index 0d269f59f78..468101fa2a9 100644 --- a/pkg/tenancy/grpc.go +++ b/pkg/tenancy/grpc.go @@ -113,3 +113,20 @@ func NewGuardingUnaryInterceptor(tc *TenancyManager) grpc.UnaryServerInterceptor return handler(WithTenant(ctx, tenant), req) } } + +// NewClientUnaryInterceptor injects tenant header into gRPC request metadata. +func NewClientUnaryInterceptor(tc *TenancyManager) grpc.UnaryClientInterceptor { + return grpc.UnaryClientInterceptor(func( + ctx context.Context, + method string, + req, reply interface{}, + cc *grpc.ClientConn, + invoker grpc.UnaryInvoker, + opts ...grpc.CallOption, + ) error { + if tenant := GetTenant(ctx); tenant != "" { + ctx = metadata.AppendToOutgoingContext(ctx, tc.Header, tenant) + } + return invoker(ctx, method, req, reply, cc, opts...) + }) +} diff --git a/pkg/tenancy/grpc_test.go b/pkg/tenancy/grpc_test.go index 1bdc898f945..ffb06f8723b 100644 --- a/pkg/tenancy/grpc_test.go +++ b/pkg/tenancy/grpc_test.go @@ -16,6 +16,7 @@ package tenancy import ( "context" + "errors" "testing" "github.com/stretchr/testify/assert" @@ -111,3 +112,22 @@ func TestTenancyInterceptors(t *testing.T) { }) } } + +func TestClientUnaryInterceptor(t *testing.T) { + tm := NewTenancyManager(&Options{Enabled: true, Tenants: []string{"acme"}}) + interceptor := NewClientUnaryInterceptor(tm) + var tenant string + fakeErr := errors.New("foo") + invoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + md, ok := metadata.FromOutgoingContext(ctx) + assert.True(t, ok) + ten, err := tenantFromMetadata(md, tm.Header) + require.NoError(t, err) + tenant = ten + return fakeErr + } + ctx := WithTenant(context.Background(), "acme") + err := interceptor(ctx, "method", "request", "responce", nil, invoker) + assert.Equal(t, "acme", tenant) + assert.Same(t, fakeErr, err) +}