diff --git a/driver/registry_memory.go b/driver/registry_memory.go index 288b122f31..b073833ec7 100644 --- a/driver/registry_memory.go +++ b/driver/registry_memory.go @@ -66,11 +66,11 @@ type RegistryMemory struct { } func (r *RegistryMemory) Init() { + _ = r.Tracer() // make sure tracer is initialized if err := r.RuleFetcher().Watch(context.Background()); err != nil { r.Logger().WithError(err).Fatal("Access rule watcher could not be initialized.") } _ = r.RuleRepository() - _ = r.Tracer() // make sure tracer is initialized } func (r *RegistryMemory) RuleFetcher() rule.Fetcher { diff --git a/middleware/grpc_middleware.go b/middleware/grpc_middleware.go index 99380a5cdf..45fec6380d 100644 --- a/middleware/grpc_middleware.go +++ b/middleware/grpc_middleware.go @@ -10,6 +10,9 @@ import ( "net/url" "strings" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" "google.golang.org/grpc" "google.golang.org/grpc/metadata" @@ -49,85 +52,110 @@ func (m *middleware) httpRequest(ctx context.Context, fullMethod string) (*http. Scheme: "grpc", } - return &http.Request{ + return (&http.Request{ Method: "POST", Proto: "HTTP/2", ProtoMajor: 2, URL: u, Host: u.Host, Header: header, - }, nil + }).WithContext(ctx), nil } +var ( + _ grpc.UnaryServerInterceptor = new(middleware).unaryInterceptor + _ grpc.StreamServerInterceptor = new(middleware).streamInterceptor +) + // UnaryInterceptor returns the gRPC unary interceptor of the middleware. func (m *middleware) UnaryInterceptor() grpc.UnaryServerInterceptor { - return func( - ctx context.Context, - req interface{}, - info *grpc.UnaryServerInfo, - handler grpc.UnaryHandler) (resp interface{}, err error) { - - log := m.Logger().WithField("middleware", "oathkeeper") - - httpReq, err := m.httpRequest(ctx, info.FullMethod) - if err != nil { - log.WithError(err).Warn("could not build HTTP request") - return nil, ErrDenied - } - log = log.WithRequest(httpReq) + return m.unaryInterceptor +} - log.Debug("matching HTTP request build from gRPC") +func (m *middleware) unaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + traceCtx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("oathkeeper/middleware").Start(ctx, "Oathkeeper.UnaryInterceptor") + defer span.End() - r, err := m.RuleMatcher().Match(ctx, httpReq.Method, httpReq.URL, rule.ProtocolGRPC) - if err != nil { - log.WithError(err).Warn("could not find a matching rule") - return nil, ErrDenied - } + log := m.Logger().WithField("middleware", "oathkeeper") - _, err = m.ProxyRequestHandler().HandleRequest(httpReq, r) - if err != nil { - log.WithError(err).Warn("failed to handle request") - return nil, ErrDenied - } + httpReq, err := m.httpRequest(traceCtx, info.FullMethod) + if err != nil { + log.WithError(err).Warn("could not build HTTP request") + span.SetAttributes(attribute.String("oathkeeper.verdict", "denied")) + span.SetStatus(codes.Error, err.Error()) + return nil, ErrDenied + } + log = log.WithRequest(httpReq) + + log.Debug("matching HTTP request build from gRPC") + + r, err := m.RuleMatcher().Match(traceCtx, httpReq.Method, httpReq.URL, rule.ProtocolGRPC) + if err != nil { + log.WithError(err).Warn("could not find a matching rule") + span.SetAttributes(attribute.String("oathkeeper.verdict", "denied")) + span.SetStatus(codes.Error, err.Error()) + return nil, ErrDenied + } - log.Info("access request granted") - return handler(ctx, req) + _, err = m.ProxyRequestHandler().HandleRequest(httpReq, r) + if err != nil { + log.WithError(err).Warn("failed to handle request") + span.SetAttributes(attribute.String("oathkeeper.verdict", "denied")) + span.SetStatus(codes.Error, err.Error()) + return nil, ErrDenied } + + log.Info("access request granted") + span.SetAttributes(attribute.String("oathkeeper.verdict", "allowed")) + span.End() + return handler(ctx, req) } // StreamInterceptor returns the gRPC stream interceptor of the middleware. func (m *middleware) StreamInterceptor() grpc.StreamServerInterceptor { - return func( - srv interface{}, - stream grpc.ServerStream, - info *grpc.StreamServerInfo, - handler grpc.StreamHandler) error { - - log := m.Logger().WithField("middleware", "oathkeeper") - ctx := stream.Context() - - httpReq, err := m.httpRequest(ctx, info.FullMethod) - if err != nil { - log.WithError(err).Warn("could not build HTTP request") - return ErrDenied - } - log = log.WithRequest(httpReq) + return m.streamInterceptor +} - log.Debug("matching HTTP request build from gRPC") +func (m *middleware) streamInterceptor( + srv interface{}, + stream grpc.ServerStream, + info *grpc.StreamServerInfo, + handler grpc.StreamHandler) (err error) { + ctx := stream.Context() + ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("oathkeeper/middleware").Start(ctx, "Oathkeeper.UnaryInterceptor") + defer span.End() + + log := m.Logger().WithField("middleware", "oathkeeper") + + httpReq, err := m.httpRequest(ctx, info.FullMethod) + if err != nil { + log.WithError(err).Warn("could not build HTTP request") + span.SetAttributes(attribute.String("oathkeeper.verdict", "denied")) + span.SetStatus(codes.Error, err.Error()) + return ErrDenied + } + log = log.WithRequest(httpReq) - r, err := m.RuleMatcher().Match(ctx, httpReq.Method, httpReq.URL, rule.ProtocolGRPC) - if err != nil { - log.WithError(err).Warn("could not find a matching rule") - return ErrDenied - } + log.Debug("matching HTTP request build from gRPC") - _, err = m.ProxyRequestHandler().HandleRequest(httpReq, r) - if err != nil { - log.WithError(err).Warn("failed to handle request") - return ErrDenied - } + r, err := m.RuleMatcher().Match(ctx, httpReq.Method, httpReq.URL, rule.ProtocolGRPC) + if err != nil { + log.WithError(err).Warn("could not find a matching rule") + span.SetAttributes(attribute.String("oathkeeper.verdict", "denied")) + span.SetStatus(codes.Error, err.Error()) + return ErrDenied + } - log.Info("access request granted") - return handler(srv, stream) + _, err = m.ProxyRequestHandler().HandleRequest(httpReq, r) + if err != nil { + log.WithError(err).Warn("failed to handle request") + span.SetAttributes(attribute.String("oathkeeper.verdict", "denied")) + span.SetStatus(codes.Error, err.Error()) + return ErrDenied } + + log.Info("access request granted") + span.SetAttributes(attribute.String("oathkeeper.verdict", "allowed")) + span.End() + return handler(srv, stream) } diff --git a/middleware/grpc_middleware_test.go b/middleware/grpc_middleware_test.go index 7cc0be3d62..f6ebfaff8d 100644 --- a/middleware/grpc_middleware_test.go +++ b/middleware/grpc_middleware_test.go @@ -26,7 +26,6 @@ import ( "google.golang.org/grpc/status" "google.golang.org/grpc/test/bufconn" "google.golang.org/grpc/test/grpc_testing" - grpcTesting "google.golang.org/grpc/test/grpc_testing" "github.com/ory/oathkeeper/driver" "github.com/ory/oathkeeper/driver/configuration" @@ -34,7 +33,7 @@ import ( "github.com/ory/oathkeeper/rule" ) -func testClient(t *testing.T, l *bufconn.Listener, dialOpts ...grpc.DialOption) grpcTesting.TestServiceClient { +func testClient(t *testing.T, l *bufconn.Listener, dialOpts ...grpc.DialOption) grpc_testing.TestServiceClient { conn, err := grpc.Dial("bufnet", append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()), @@ -45,7 +44,7 @@ func testClient(t *testing.T, l *bufconn.Listener, dialOpts ...grpc.DialOption) require.NoError(t, err) t.Cleanup(func() { conn.Close() }) - return grpcTesting.NewTestServiceClient(conn) + return grpc_testing.NewTestServiceClient(conn) } func testTokenCheckServer(t *testing.T) *httptest.Server { @@ -126,7 +125,7 @@ mutators: grpc.UnaryInterceptor(mw.UnaryInterceptor()), grpc.StreamInterceptor(mw.StreamInterceptor()), ) - grpcTesting.RegisterTestServiceServer(s, upstream) + grpc_testing.RegisterTestServiceServer(s, upstream) go func() { if err := s.Serve(l); err != nil { t.Logf("Server exited with error: %v", err) @@ -137,7 +136,7 @@ mutators: upstream.EXPECT(). EmptyCall(gomock.Any(), gomock.Any()). AnyTimes(). - Return(&grpcTesting.Empty{}, nil) + Return(&grpc_testing.Empty{}, nil) cases := []struct { name string @@ -337,13 +336,13 @@ mutators: require.NoError(t, reg.RuleRepository().SetMatchingStrategy(ctx, s)) require.NoError(t, reg.RuleRepository().Set(ctx, tc.rules[s])) - _, err := client.EmptyCall(ctx, &grpcTesting.Empty{}) + _, err := client.EmptyCall(ctx, &grpc_testing.Empty{}) tc.assert(t, err) - _, err = client.UnaryCall(ctx, &grpcTesting.SimpleRequest{}) + _, err = client.UnaryCall(ctx, &grpc_testing.SimpleRequest{}) assertErrDenied(t, err) - stream, _ := client.StreamingOutputCall(ctx, &grpcTesting.StreamingOutputCallRequest{}) + stream, _ := client.StreamingOutputCall(ctx, &grpc_testing.StreamingOutputCallRequest{}) _, err = stream.Recv() assertErrDenied(t, err) })