Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean-up APIv3 tests #5046

Merged
merged 1 commit into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/query/app/apiv3/grpc_gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import (
)

// RegisterGRPCGateway registers api_v3 endpoints into provided mux.
func RegisterGRPCGateway(ctx context.Context, logger *zap.Logger, r *mux.Router, basePath string, grpcEndpoint string, grpcTLS tlscfg.Options, tm *tenancy.Manager) error {
func RegisterGRPCGateway(ctx context.Context, logger *zap.Logger, r *mux.Router, basePath string, grpcEndpoint string, grpcTLS *tlscfg.Options, tm *tenancy.Manager) error {
grpcEndpoint = netutils.FixLocalhost([]string{grpcEndpoint})[0]
jsonpb := &runtime.JSONPb{}

Expand Down
115 changes: 73 additions & 42 deletions cmd/query/app/apiv3/grpc_gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
package apiv3

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"strings"
Expand All @@ -27,7 +27,6 @@ import (
"github.com/gorilla/mux"
"github.com/grpc-ecosystem/grpc-gateway/runtime"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"google.golang.org/grpc"
Expand All @@ -45,23 +44,32 @@ import (

var testCertKeyLocation = "../../../../pkg/config/tlscfg/testdata/"

func testGRPCGateway(t *testing.T, basePath string, serverTLS tlscfg.Options, clientTLS tlscfg.Options) {
testGRPCGatewayWithTenancy(t, basePath, serverTLS, clientTLS,
tenancy.Options{
Enabled: false,
},
func(*http.Request) {})
type testGateway struct {
reader *spanstoremocks.Reader
url string
}

func setupGRPCGateway(t *testing.T, basePath string, serverTLS tlscfg.Options, clientTLS tlscfg.Options, tenancyOptions tenancy.Options) (*spanstoremocks.Reader, net.Listener, *grpc.Server, context.CancelFunc, *http.Server) {
r := &spanstoremocks.Reader{}
func setupGRPCGateway(
t *testing.T,
basePath string,
serverTLS, clientTLS *tlscfg.Options,
tenancyOptions tenancy.Options,
) *testGateway {
// *spanstoremocks.Reader, net.Listener, *grpc.Server, context.CancelFunc, *http.Server
gw := &testGateway{
reader: &spanstoremocks.Reader{},
}

q := querysvc.NewQueryService(r, &dependencyStoreMocks.Reader{}, querysvc.QueryServiceOptions{})
q := querysvc.NewQueryService(gw.reader,
&dependencyStoreMocks.Reader{},
querysvc.QueryServiceOptions{},
)

var serverGRPCOpts []grpc.ServerOption
if serverTLS.Enabled {
config, err := serverTLS.Config(zap.NewNop())
require.NoError(t, err)
t.Cleanup(func() { serverTLS.Close() })
creds := credentials.NewTLS(config)
serverGRPCOpts = append(serverGRPCOpts, grpc.Creds(creds))
}
Expand All @@ -77,17 +85,25 @@ func setupGRPCGateway(t *testing.T, basePath string, serverTLS tlscfg.Options, c
QueryService: q,
}
api_v3.RegisterQueryServiceServer(grpcServer, h)
lis, _ := net.Listen("tcp", ":0")
lis, err := net.Listen("tcp", ":0")
require.NoError(t, err)

go func() {
err := grpcServer.Serve(lis)
require.NoError(t, err)
}()
t.Cleanup(func() { grpcServer.Stop() })

router := &mux.Router{}
router = router.PathPrefix(basePath).Subrouter()
ctx, cancel := context.WithCancel(context.Background())
err := RegisterGRPCGateway(ctx, zap.NewNop(), router, basePath, lis.Addr().String(), clientTLS, tenancy.NewManager(&tenancyOptions))
err = RegisterGRPCGateway(
ctx, zap.NewNop(), router, basePath,
lis.Addr().String(), clientTLS, tenancy.NewManager(&tenancyOptions),
)
require.NoError(t, err)
t.Cleanup(func() { cancel() })
t.Cleanup(func() { clientTLS.Close() })

httpLis, err := net.Listen("tcp", ":0")
require.NoError(t, err)
Expand All @@ -98,23 +114,39 @@ func setupGRPCGateway(t *testing.T, basePath string, serverTLS tlscfg.Options, c
err = httpServer.Serve(httpLis)
require.Equal(t, http.ErrServerClosed, err)
}()
return r, httpLis, grpcServer, cancel, httpServer
t.Cleanup(func() { httpServer.Shutdown(context.Background()) })

gw.url = fmt.Sprintf(
"http://localhost%s%s",
strings.Replace(httpLis.Addr().String(), "[::]", "", 1),
basePath)
return gw
}

func testGRPCGatewayWithTenancy(t *testing.T, basePath string, serverTLS tlscfg.Options, clientTLS tlscfg.Options,
func testGRPCGateway(
t *testing.T, basePath string,
serverTLS, clientTLS *tlscfg.Options,
) {
testGRPCGatewayWithTenancy(t, basePath, serverTLS, clientTLS,
tenancy.Options{
Enabled: false,
},
func(*http.Request) { /* setupRequest : no changes */ },
)
}

func testGRPCGatewayWithTenancy(
t *testing.T,
basePath string,
serverTLS *tlscfg.Options,
clientTLS *tlscfg.Options,
tenancyOptions tenancy.Options,
setupRequest func(*http.Request),
) {
defer serverTLS.Close()
defer clientTLS.Close()

reader, httpLis, grpcServer, cancel, httpServer := setupGRPCGateway(t, basePath, serverTLS, clientTLS, tenancyOptions)
defer grpcServer.Stop()
defer cancel()
defer httpServer.Shutdown(context.Background())
gw := setupGRPCGateway(t, basePath, serverTLS, clientTLS, tenancyOptions)

traceID := model.NewTraceID(150, 160)
reader.On("GetTrace", mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("model.TraceID")).Return(
gw.reader.On("GetTrace", matchContext, matchTraceID).Return(
&model.Trace{
Spans: []*model.Span{
{
Expand All @@ -125,28 +157,28 @@ func testGRPCGatewayWithTenancy(t *testing.T, basePath string, serverTLS tlscfg.
},
}, nil).Once()

req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://localhost%s%s/api/v3/traces/123", strings.Replace(httpLis.Addr().String(), "[::]", "", 1), basePath), nil)
req, err := http.NewRequest(http.MethodGet, gw.url+"/api/v3/traces/123", nil)
require.NoError(t, err)
req.Header.Set("Content-Type", "application/json")
setupRequest(req)
response, err := http.DefaultClient.Do(req)
require.NoError(t, err)
buf := bytes.Buffer{}
_, err = buf.ReadFrom(response.Body)
body, err := io.ReadAll(response.Body)
require.NoError(t, err)
require.NoError(t, response.Body.Close())

jsonpb := &runtime.JSONPb{}
var envelope envelope
err = json.Unmarshal(buf.Bytes(), &envelope)
err = json.Unmarshal(body, &envelope)
require.NoError(t, err)
var spansResponse api_v3.SpansResponseChunk
err = jsonpb.Unmarshal(envelope.Result, &spansResponse)
require.NoError(t, err)
assert.Len(t, spansResponse.GetResourceSpans(), 1)
assert.Equal(t, uint64ToTraceID(t, traceID.High, traceID.Low), spansResponse.GetResourceSpans()[0].GetScopeSpans()[0].GetSpans()[0].GetTraceId())
assert.Equal(t, bytesOfTraceID(t, traceID.High, traceID.Low), spansResponse.GetResourceSpans()[0].GetScopeSpans()[0].GetSpans()[0].GetTraceId())
}

func uint64ToTraceID(t *testing.T, high, low uint64) []byte {
func bytesOfTraceID(t *testing.T, high, low uint64) []byte {
traceID := model.NewTraceID(high, low)
buf := make([]byte, 16)
_, err := traceID.MarshalTo(buf)
Expand All @@ -155,17 +187,17 @@ func uint64ToTraceID(t *testing.T, high, low uint64) []byte {
}

func TestGRPCGateway(t *testing.T) {
testGRPCGateway(t, "/", tlscfg.Options{}, tlscfg.Options{})
testGRPCGateway(t, "/", &tlscfg.Options{}, &tlscfg.Options{})
}

func TestGRPCGateway_TLS_with_base_path(t *testing.T) {
serverTLS := tlscfg.Options{
func TestGRPCGatewayWithBasePathAndTLS(t *testing.T) {
serverTLS := &tlscfg.Options{
Enabled: true,
CAPath: testCertKeyLocation + "/example-CA-cert.pem",
CertPath: testCertKeyLocation + "/example-server-cert.pem",
KeyPath: testCertKeyLocation + "/example-server-key.pem",
}
clientTLS := tlscfg.Options{
clientTLS := &tlscfg.Options{
Enabled: true,
CAPath: testCertKeyLocation + "/example-CA-cert.pem",
CertPath: testCertKeyLocation + "/example-client-cert.pem",
Expand All @@ -180,12 +212,12 @@ type envelope struct {
Result json.RawMessage `json:"result"`
}

func TestTenancyGRPCGateway(t *testing.T) {
func TestGRPCGatewayWithTenancy(t *testing.T) {
tenancyOptions := tenancy.Options{
Enabled: true,
}
tm := tenancy.NewManager(&tenancyOptions)
testGRPCGatewayWithTenancy(t, "/", tlscfg.Options{}, tlscfg.Options{},
testGRPCGatewayWithTenancy(t, "/", &tlscfg.Options{}, &tlscfg.Options{},
// Configure the gateway to forward tenancy header from HTTP to GRPC
tenancyOptions,
// Add a tenancy header on outbound requests
Expand All @@ -197,15 +229,12 @@ func TestTenancyGRPCGateway(t *testing.T) {
func TestTenancyGRPCRejection(t *testing.T) {
basePath := "/"
tenancyOptions := tenancy.Options{Enabled: true}
reader, httpLis, grpcServer, cancel, httpServer := setupGRPCGateway(t,
basePath, tlscfg.Options{}, tlscfg.Options{},
gw := setupGRPCGateway(t,
basePath, &tlscfg.Options{}, &tlscfg.Options{},
tenancyOptions)
defer grpcServer.Stop()
defer cancel()
defer httpServer.Shutdown(context.Background())

traceID := model.NewTraceID(150, 160)
reader.On("GetTrace", mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("model.TraceID")).Return(
gw.reader.On("GetTrace", matchContext, matchTraceID).Return(
&model.Trace{
Spans: []*model.Span{
{
Expand All @@ -216,19 +245,21 @@ func TestTenancyGRPCRejection(t *testing.T) {
},
}, nil).Once()

req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://localhost%s%s/api/v3/traces/123", strings.Replace(httpLis.Addr().String(), "[::]", "", 1), basePath), nil)
req, err := http.NewRequest(http.MethodGet, gw.url+"/api/v3/traces/123", nil)
require.NoError(t, err)
req.Header.Set("Content-Type", "application/json")
// We don't set tenant header
response, err := http.DefaultClient.Do(req)
require.NoError(t, err)
require.NoError(t, response.Body.Close())
require.Equal(t, http.StatusForbidden, response.StatusCode)

// Try again with tenant header set
tm := tenancy.NewManager(&tenancyOptions)
req.Header.Set(tm.Header, "acme")
response, err = http.DefaultClient.Do(req)
require.NoError(t, err)
require.NoError(t, response.Body.Close())
require.Equal(t, http.StatusOK, response.StatusCode)
// Skip unmarshal of response; it is enough that it succeeded
}
Loading
Loading