Skip to content

Commit

Permalink
feat: tracing for gRPC middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
alnr committed Apr 3, 2023
1 parent efa77fe commit b32e7b7
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 66 deletions.
2 changes: 1 addition & 1 deletion driver/registry_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ type RegistryMemory struct {
}

func (r *RegistryMemory) Init() {
_ = r.Tracer() // make sure tracer is initialized
if err := r.RuleFetcher().Watch(context.Background()); err != nil {
r.Logger().WithError(err).Fatal("Access rule watcher could not be initialized.")
}
_ = r.RuleRepository()
_ = r.Tracer() // make sure tracer is initialized
}

func (r *RegistryMemory) RuleFetcher() rule.Fetcher {
Expand Down
142 changes: 85 additions & 57 deletions middleware/grpc_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ import (
"net/url"
"strings"

"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"

Expand Down Expand Up @@ -49,85 +52,110 @@ func (m *middleware) httpRequest(ctx context.Context, fullMethod string) (*http.
Scheme: "grpc",
}

return &http.Request{
return (&http.Request{
Method: "POST",
Proto: "HTTP/2",
ProtoMajor: 2,
URL: u,
Host: u.Host,
Header: header,
}, nil
}).WithContext(ctx), nil
}

var (
_ grpc.UnaryServerInterceptor = new(middleware).unaryInterceptor
_ grpc.StreamServerInterceptor = new(middleware).streamInterceptor
)

// UnaryInterceptor returns the gRPC unary interceptor of the middleware.
func (m *middleware) UnaryInterceptor() grpc.UnaryServerInterceptor {
return func(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (resp interface{}, err error) {

log := m.Logger().WithField("middleware", "oathkeeper")

httpReq, err := m.httpRequest(ctx, info.FullMethod)
if err != nil {
log.WithError(err).Warn("could not build HTTP request")
return nil, ErrDenied
}
log = log.WithRequest(httpReq)
return m.unaryInterceptor
}

log.Debug("matching HTTP request build from gRPC")
func (m *middleware) unaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
traceCtx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("oathkeeper/middleware").Start(ctx, "Oathkeeper.UnaryInterceptor")
defer span.End()

r, err := m.RuleMatcher().Match(ctx, httpReq.Method, httpReq.URL, rule.ProtocolGRPC)
if err != nil {
log.WithError(err).Warn("could not find a matching rule")
return nil, ErrDenied
}
log := m.Logger().WithField("middleware", "oathkeeper")

_, err = m.ProxyRequestHandler().HandleRequest(httpReq, r)
if err != nil {
log.WithError(err).Warn("failed to handle request")
return nil, ErrDenied
}
httpReq, err := m.httpRequest(traceCtx, info.FullMethod)
if err != nil {
log.WithError(err).Warn("could not build HTTP request")
span.SetAttributes(attribute.String("oathkeeper.verdict", "denied"))
span.SetStatus(codes.Error, err.Error())
return nil, ErrDenied
}
log = log.WithRequest(httpReq)

log.Debug("matching HTTP request build from gRPC")

r, err := m.RuleMatcher().Match(traceCtx, httpReq.Method, httpReq.URL, rule.ProtocolGRPC)
if err != nil {
log.WithError(err).Warn("could not find a matching rule")
span.SetAttributes(attribute.String("oathkeeper.verdict", "denied"))
span.SetStatus(codes.Error, err.Error())
return nil, ErrDenied
}

log.Info("access request granted")
return handler(ctx, req)
_, err = m.ProxyRequestHandler().HandleRequest(httpReq, r)
if err != nil {
log.WithError(err).Warn("failed to handle request")
span.SetAttributes(attribute.String("oathkeeper.verdict", "denied"))
span.SetStatus(codes.Error, err.Error())
return nil, ErrDenied
}

log.Info("access request granted")
span.SetAttributes(attribute.String("oathkeeper.verdict", "allowed"))
span.End()
return handler(ctx, req)
}

// StreamInterceptor returns the gRPC stream interceptor of the middleware.
func (m *middleware) StreamInterceptor() grpc.StreamServerInterceptor {
return func(
srv interface{},
stream grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler) error {

log := m.Logger().WithField("middleware", "oathkeeper")
ctx := stream.Context()

httpReq, err := m.httpRequest(ctx, info.FullMethod)
if err != nil {
log.WithError(err).Warn("could not build HTTP request")
return ErrDenied
}
log = log.WithRequest(httpReq)
return m.streamInterceptor
}

log.Debug("matching HTTP request build from gRPC")
func (m *middleware) streamInterceptor(
srv interface{},
stream grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler) (err error) {
ctx := stream.Context()
ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("oathkeeper/middleware").Start(ctx, "Oathkeeper.UnaryInterceptor")
defer span.End()

log := m.Logger().WithField("middleware", "oathkeeper")

httpReq, err := m.httpRequest(ctx, info.FullMethod)
if err != nil {
log.WithError(err).Warn("could not build HTTP request")
span.SetAttributes(attribute.String("oathkeeper.verdict", "denied"))
span.SetStatus(codes.Error, err.Error())
return ErrDenied
}
log = log.WithRequest(httpReq)

r, err := m.RuleMatcher().Match(ctx, httpReq.Method, httpReq.URL, rule.ProtocolGRPC)
if err != nil {
log.WithError(err).Warn("could not find a matching rule")
return ErrDenied
}
log.Debug("matching HTTP request build from gRPC")

_, err = m.ProxyRequestHandler().HandleRequest(httpReq, r)
if err != nil {
log.WithError(err).Warn("failed to handle request")
return ErrDenied
}
r, err := m.RuleMatcher().Match(ctx, httpReq.Method, httpReq.URL, rule.ProtocolGRPC)
if err != nil {
log.WithError(err).Warn("could not find a matching rule")
span.SetAttributes(attribute.String("oathkeeper.verdict", "denied"))
span.SetStatus(codes.Error, err.Error())
return ErrDenied
}

log.Info("access request granted")
return handler(srv, stream)
_, err = m.ProxyRequestHandler().HandleRequest(httpReq, r)
if err != nil {
log.WithError(err).Warn("failed to handle request")
span.SetAttributes(attribute.String("oathkeeper.verdict", "denied"))
span.SetStatus(codes.Error, err.Error())
return ErrDenied
}

log.Info("access request granted")
span.SetAttributes(attribute.String("oathkeeper.verdict", "allowed"))
span.End()
return handler(srv, stream)
}
15 changes: 7 additions & 8 deletions middleware/grpc_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@ import (
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn"
"google.golang.org/grpc/test/grpc_testing"
grpcTesting "google.golang.org/grpc/test/grpc_testing"

"github.com/ory/oathkeeper/driver"
"github.com/ory/oathkeeper/driver/configuration"
"github.com/ory/oathkeeper/middleware"
"github.com/ory/oathkeeper/rule"
)

func testClient(t *testing.T, l *bufconn.Listener, dialOpts ...grpc.DialOption) grpcTesting.TestServiceClient {
func testClient(t *testing.T, l *bufconn.Listener, dialOpts ...grpc.DialOption) grpc_testing.TestServiceClient {
conn, err := grpc.Dial("bufnet",
append(dialOpts,
grpc.WithTransportCredentials(insecure.NewCredentials()),
Expand All @@ -45,7 +44,7 @@ func testClient(t *testing.T, l *bufconn.Listener, dialOpts ...grpc.DialOption)
require.NoError(t, err)
t.Cleanup(func() { conn.Close() })

return grpcTesting.NewTestServiceClient(conn)
return grpc_testing.NewTestServiceClient(conn)
}

func testTokenCheckServer(t *testing.T) *httptest.Server {
Expand Down Expand Up @@ -126,7 +125,7 @@ mutators:
grpc.UnaryInterceptor(mw.UnaryInterceptor()),
grpc.StreamInterceptor(mw.StreamInterceptor()),
)
grpcTesting.RegisterTestServiceServer(s, upstream)
grpc_testing.RegisterTestServiceServer(s, upstream)
go func() {
if err := s.Serve(l); err != nil {
t.Logf("Server exited with error: %v", err)
Expand All @@ -137,7 +136,7 @@ mutators:
upstream.EXPECT().
EmptyCall(gomock.Any(), gomock.Any()).
AnyTimes().
Return(&grpcTesting.Empty{}, nil)
Return(&grpc_testing.Empty{}, nil)

cases := []struct {
name string
Expand Down Expand Up @@ -337,13 +336,13 @@ mutators:
require.NoError(t, reg.RuleRepository().SetMatchingStrategy(ctx, s))
require.NoError(t, reg.RuleRepository().Set(ctx, tc.rules[s]))

_, err := client.EmptyCall(ctx, &grpcTesting.Empty{})
_, err := client.EmptyCall(ctx, &grpc_testing.Empty{})
tc.assert(t, err)

_, err = client.UnaryCall(ctx, &grpcTesting.SimpleRequest{})
_, err = client.UnaryCall(ctx, &grpc_testing.SimpleRequest{})
assertErrDenied(t, err)

stream, _ := client.StreamingOutputCall(ctx, &grpcTesting.StreamingOutputCallRequest{})
stream, _ := client.StreamingOutputCall(ctx, &grpc_testing.StreamingOutputCallRequest{})
_, err = stream.Recv()
assertErrDenied(t, err)
})
Expand Down

0 comments on commit b32e7b7

Please sign in to comment.