Skip to content

Commit

Permalink
wrap middleware in an order asserting interceptor
Browse files Browse the repository at this point in the history
  • Loading branch information
vroldanbet committed Jun 13, 2023
1 parent 3983967 commit 8a89d6f
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 26 deletions.
156 changes: 130 additions & 26 deletions pkg/cmd/server/defaults.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package server

import (
"context"
"encoding/json"
"flag"
"fmt"
middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"net/http"
"net/http/pprof"

Expand Down Expand Up @@ -195,52 +198,65 @@ func DefaultUnaryMiddleware(logger zerolog.Logger, authFunc grpcauth.AuthFunc, e
func DefaultStreamingMiddleware(logger zerolog.Logger, authFunc grpcauth.AuthFunc, enableVersionResponse bool, dispatcher dispatch.Dispatcher, ds datastore.Datastore) (*MiddlewareChain[grpc.StreamServerInterceptor], error) {
chain, err := NewMiddlewareChain([]ReferenceableMiddleware[grpc.StreamServerInterceptor]{
{
Name: DefaultMiddlewareRequestID,
Middleware: requestid.StreamServerInterceptor(requestid.GenerateIfMissing(true)),
Name: DefaultMiddlewareRequestID,
Middleware: MustStreamingOrder(requestid.StreamServerInterceptor(requestid.GenerateIfMissing(true)),
DefaultMiddlewareRequestID, DefaultMiddlewareLog, ""),
},
{
Name: DefaultMiddlewareLog,
Middleware: logmw.StreamServerInterceptor(logmw.ExtractMetadataField("x-request-id", "requestID")),
Name: DefaultMiddlewareLog,
Middleware: MustStreamingOrder(
logmw.StreamServerInterceptor(logmw.ExtractMetadataField("x-request-id", "requestID")),
DefaultMiddlewareLog, DefaultMiddlewareGRPCLog, DefaultMiddlewareRequestID,
),
},
{
Name: DefaultMiddlewareGRPCLog,
Middleware: grpclog.StreamServerInterceptor(grpczerolog.InterceptorLogger(logger), defaultGRPCLogOptions...),
Name: DefaultMiddlewareGRPCLog,
Middleware: MustStreamingOrder(grpclog.StreamServerInterceptor(grpczerolog.InterceptorLogger(logger), defaultGRPCLogOptions...),
DefaultMiddlewareGRPCLog, DefaultMiddlewareOTelGRPC, DefaultMiddlewareLog),
},
{
Name: DefaultMiddlewareOTelGRPC,
Middleware: otelgrpc.StreamServerInterceptor(),
Name: DefaultMiddlewareOTelGRPC,
Middleware: MustStreamingOrder(otelgrpc.StreamServerInterceptor(),
DefaultMiddlewareOTelGRPC, DefaultMiddlewareGRPCProm, DefaultMiddlewareGRPCLog),
},
{
Name: DefaultMiddlewareGRPCProm,
Middleware: grpcprom.StreamServerInterceptor,
Name: DefaultMiddlewareGRPCProm,
Middleware: MustStreamingOrder(grpcprom.StreamServerInterceptor,
DefaultMiddlewareGRPCProm, DefaultMiddlewareGRPCAuth, DefaultMiddlewareOTelGRPC),
},
{
Name: DefaultMiddlewareGRPCAuth,
Middleware: grpcauth.StreamServerInterceptor(authFunc),
Name: DefaultMiddlewareGRPCAuth,
Middleware: MustStreamingOrder(grpcauth.StreamServerInterceptor(authFunc),
DefaultMiddlewareGRPCAuth, DefaultMiddlewareServerVersion, DefaultMiddlewareGRPCProm),
},
{
Name: DefaultMiddlewareServerVersion,
Middleware: serverversion.StreamServerInterceptor(enableVersionResponse),
Name: DefaultMiddlewareServerVersion,
Middleware: MustStreamingOrder(serverversion.StreamServerInterceptor(enableVersionResponse),
DefaultMiddlewareServerVersion, DefaultInternalMiddlewareDispatch, DefaultMiddlewareGRPCAuth),
},
{
Name: DefaultInternalMiddlewareDispatch,
Internal: true,
Middleware: dispatchmw.StreamServerInterceptor(dispatcher),
Name: DefaultInternalMiddlewareDispatch,
Internal: true,
Middleware: MustStreamingOrder(dispatchmw.StreamServerInterceptor(dispatcher),
DefaultInternalMiddlewareDispatch, DefaultInternalMiddlewareDatastore, DefaultMiddlewareServerVersion),
},
{
Name: DefaultInternalMiddlewareDatastore,
Internal: true,
Middleware: datastoremw.StreamServerInterceptor(ds),
Name: DefaultInternalMiddlewareDatastore,
Internal: true,
Middleware: MustStreamingOrder(datastoremw.StreamServerInterceptor(ds),
DefaultInternalMiddlewareDatastore, DefaultInternalMiddlewareConsistency, DefaultInternalMiddlewareDispatch),
},
{
Name: DefaultInternalMiddlewareConsistency,
Internal: true,
Middleware: consistencymw.StreamServerInterceptor(),
Name: DefaultInternalMiddlewareConsistency,
Internal: true,
Middleware: MustStreamingOrder(consistencymw.StreamServerInterceptor(),
DefaultInternalMiddlewareConsistency, DefaultInternalMiddlewareServerSpecific, DefaultInternalMiddlewareDatastore),
},
{
Name: DefaultInternalMiddlewareServerSpecific,
Internal: true,
Middleware: servicespecific.StreamServerInterceptor,
Name: DefaultInternalMiddlewareServerSpecific,
Internal: true,
Middleware: MustStreamingOrder(servicespecific.StreamServerInterceptor,
DefaultInternalMiddlewareServerSpecific, "", DefaultInternalMiddlewareConsistency),
},
}...)
return &chain, err
Expand Down Expand Up @@ -268,3 +284,91 @@ func DefaultDispatchMiddleware(logger zerolog.Logger, authFunc grpcauth.AuthFunc
servicespecific.StreamServerInterceptor,
}
}

func inTest() bool {
return flag.Lookup("test.v") != nil
}

type streamOrderAssertion struct {
grpc.ServerStream
name string
alreadyExecuted string
notExecuted string
}

func (o streamOrderAssertion) RecvMsg(m any) error {
mustHaveExecuted(o.Context(), o.alreadyExecuted)
mustHaveNotExecuted(o.Context(), o.notExecuted)
markAsExecuted(o.Context(), o.name)
err := o.ServerStream.RecvMsg(m)
return err
}

func (o streamOrderAssertion) SendMsg(m any) error {
return o.ServerStream.SendMsg(m)
}

func MustStreamingOrder(interceptor grpc.StreamServerInterceptor, name, alreadyExecuted, notExecuted string) grpc.StreamServerInterceptor {
if !inTest() {
return interceptor
}
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
wss := middleware.WrapServerStream(ss)
if wss.WrappedContext.Value(interceptorsExecuted) == nil {
handle := executedHandle{executed: make(map[string]struct{}, 0)}
wss.WrappedContext = context.WithValue(wss.WrappedContext, interceptorsExecuted, &handle)
}
wrappedStream := streamOrderAssertion{
ServerStream: wss,
name: name,
alreadyExecuted: alreadyExecuted,
notExecuted: notExecuted,
}
return interceptor(srv, wrappedStream, info, handler)
}
}

func mustHaveNotExecuted(ctx context.Context, notExecuted string) {
if notExecuted == "" {
return
}
val := ctx.Value(interceptorsExecuted)
if val == nil {
return
}
handle := val.(*executedHandle)
if _, ok := handle.executed[notExecuted]; ok {
panic("expected interceptor " + notExecuted + " to be not executed")
}
}

func mustHaveExecuted(ctx context.Context, expectedExecuted string) {
if expectedExecuted == "" {
return
}
val := ctx.Value(interceptorsExecuted)
if val == nil {
panic("expected interceptor " + expectedExecuted + " to be executed")
}
handle := val.(*executedHandle)
if _, ok := handle.executed[expectedExecuted]; ok {
return
}
panic("expected interceptor " + expectedExecuted + " to be executed")
}

func markAsExecuted(ctx context.Context, name string) {
val := ctx.Value(interceptorsExecuted)
if val == nil {
panic("handle should exist")
} else {
handle := val.(*executedHandle)
handle.executed[name] = struct{}{}
}
}

type executedHandle struct {
executed map[string]struct{}
}

var interceptorsExecuted = struct{}{}
52 changes: 52 additions & 0 deletions pkg/cmd/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package server
import (
"context"
"errors"
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
"github.com/authzed/spicedb/pkg/cmd/datastore"
"testing"
"time"

Expand Down Expand Up @@ -53,6 +55,56 @@ func TestServerGracefulTermination(t *testing.T) {
<-ch
}

func TestStreamingMiddlewareOrder(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

ds, err := datastore.NewDatastore(ctx,
datastore.DefaultDatastoreConfig().ToOption(),
datastore.WithBootstrapFiles("testdata/test_schema.yaml"),
datastore.WithRequestHedgingEnabled(false),
)
require.NoError(t, err)

c := ConfigWithOptions(
&Config{},
WithPresharedSecureKey("psk"),
WithDatastore(ds),
WithGRPCServer(util.GRPCServerConfig{
Network: util.BufferedNetwork,
Enabled: true,
}),
)
rs, err := c.Complete(ctx)
require.NoError(t, err)

clientConn, err := rs.GRPCDialContext(ctx)
require.NoError(t, err)

psc := v1.NewPermissionsServiceClient(clientConn)

go func() {
_ = rs.Run(ctx)
}()
time.Sleep(100 * time.Millisecond)

req := &v1.LookupResourcesRequest{
ResourceObjectType: "resource",
Subject: &v1.SubjectReference{
Object: &v1.ObjectReference{
ObjectType: "user",
ObjectId: "user1",
},
},
Permission: "read",
}
lrc, err := psc.LookupResources(ctx, req)
require.NoError(t, err)

_, err = lrc.Recv()
require.NoError(t, err)
}

func TestServerGracefulTerminationOnError(t *testing.T) {
defer goleak.VerifyNone(t, goleak.IgnoreCurrent())

Expand Down
12 changes: 12 additions & 0 deletions pkg/cmd/server/testdata/test_schema.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
schema: |-
definition user {}
definition resource {
relation reader: user
permission read = reader
}
relationships: |
resource:doc1#reader@user:user1

0 comments on commit 8a89d6f

Please sign in to comment.