From 45e287ae1e00d6840c3b4661f9ab4d6e0d6a2346 Mon Sep 17 00:00:00 2001 From: Jason Parraga Date: Thu, 1 Aug 2024 15:19:05 -0700 Subject: [PATCH] [flyteadmin] Refactor panic recovery into middleware (#5546) * Refactor panic handling to middleware Signed-off-by: Jason Parraga * Remove registration of old panicCounter Signed-off-by: Jason Parraga * Add test coverage Signed-off-by: Jason Parraga --------- Signed-off-by: Jason Parraga --- flyteadmin/pkg/rpc/adminservice/attributes.go | 10 --- flyteadmin/pkg/rpc/adminservice/base.go | 13 --- flyteadmin/pkg/rpc/adminservice/base_test.go | 40 --------- .../rpc/adminservice/description_entity.go | 2 - flyteadmin/pkg/rpc/adminservice/execution.go | 10 --- .../pkg/rpc/adminservice/launch_plan.go | 7 -- flyteadmin/pkg/rpc/adminservice/metrics.go | 7 +- .../middleware/recovery_interceptor.go | 61 +++++++++++++ .../middleware/recovery_interceptor_test.go | 90 +++++++++++++++++++ .../pkg/rpc/adminservice/named_entity.go | 3 - .../pkg/rpc/adminservice/node_execution.go | 6 -- flyteadmin/pkg/rpc/adminservice/project.go | 5 -- flyteadmin/pkg/rpc/adminservice/task.go | 4 - .../pkg/rpc/adminservice/task_execution.go | 4 - flyteadmin/pkg/rpc/adminservice/version.go | 1 - flyteadmin/pkg/rpc/adminservice/workflow.go | 4 - flyteadmin/pkg/server/service.go | 29 +++++- 17 files changed, 177 insertions(+), 119 deletions(-) delete mode 100644 flyteadmin/pkg/rpc/adminservice/base_test.go create mode 100644 flyteadmin/pkg/rpc/adminservice/middleware/recovery_interceptor.go create mode 100644 flyteadmin/pkg/rpc/adminservice/middleware/recovery_interceptor_test.go diff --git a/flyteadmin/pkg/rpc/adminservice/attributes.go b/flyteadmin/pkg/rpc/adminservice/attributes.go index 46607da93e..62002a0e6e 100644 --- a/flyteadmin/pkg/rpc/adminservice/attributes.go +++ b/flyteadmin/pkg/rpc/adminservice/attributes.go @@ -12,7 +12,6 @@ import ( func (m *AdminService) UpdateWorkflowAttributes(ctx context.Context, request *admin.WorkflowAttributesUpdateRequest) ( *admin.WorkflowAttributesUpdateResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -30,7 +29,6 @@ func (m *AdminService) UpdateWorkflowAttributes(ctx context.Context, request *ad func (m *AdminService) GetWorkflowAttributes(ctx context.Context, request *admin.WorkflowAttributesGetRequest) ( *admin.WorkflowAttributesGetResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -48,7 +46,6 @@ func (m *AdminService) GetWorkflowAttributes(ctx context.Context, request *admin func (m *AdminService) DeleteWorkflowAttributes(ctx context.Context, request *admin.WorkflowAttributesDeleteRequest) ( *admin.WorkflowAttributesDeleteResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -66,7 +63,6 @@ func (m *AdminService) DeleteWorkflowAttributes(ctx context.Context, request *ad func (m *AdminService) UpdateProjectDomainAttributes(ctx context.Context, request *admin.ProjectDomainAttributesUpdateRequest) ( *admin.ProjectDomainAttributesUpdateResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -84,7 +80,6 @@ func (m *AdminService) UpdateProjectDomainAttributes(ctx context.Context, reques func (m *AdminService) GetProjectDomainAttributes(ctx context.Context, request *admin.ProjectDomainAttributesGetRequest) ( *admin.ProjectDomainAttributesGetResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -102,7 +97,6 @@ func (m *AdminService) GetProjectDomainAttributes(ctx context.Context, request * func (m *AdminService) DeleteProjectDomainAttributes(ctx context.Context, request *admin.ProjectDomainAttributesDeleteRequest) ( *admin.ProjectDomainAttributesDeleteResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -121,7 +115,6 @@ func (m *AdminService) DeleteProjectDomainAttributes(ctx context.Context, reques func (m *AdminService) UpdateProjectAttributes(ctx context.Context, request *admin.ProjectAttributesUpdateRequest) ( *admin.ProjectAttributesUpdateResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -140,7 +133,6 @@ func (m *AdminService) UpdateProjectAttributes(ctx context.Context, request *adm func (m *AdminService) GetProjectAttributes(ctx context.Context, request *admin.ProjectAttributesGetRequest) ( *admin.ProjectAttributesGetResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -159,7 +151,6 @@ func (m *AdminService) GetProjectAttributes(ctx context.Context, request *admin. func (m *AdminService) DeleteProjectAttributes(ctx context.Context, request *admin.ProjectAttributesDeleteRequest) ( *admin.ProjectAttributesDeleteResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -177,7 +168,6 @@ func (m *AdminService) DeleteProjectAttributes(ctx context.Context, request *adm func (m *AdminService) ListMatchableAttributes(ctx context.Context, request *admin.ListMatchableAttributesRequest) ( *admin.ListMatchableAttributesResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } diff --git a/flyteadmin/pkg/rpc/adminservice/base.go b/flyteadmin/pkg/rpc/adminservice/base.go index 5a2cb2ad89..8df2c595c7 100644 --- a/flyteadmin/pkg/rpc/adminservice/base.go +++ b/flyteadmin/pkg/rpc/adminservice/base.go @@ -5,8 +5,6 @@ import ( "fmt" "runtime/debug" - "github.com/golang/protobuf/proto" - "github.com/flyteorg/flyte/flyteadmin/pkg/async/cloudevent" eventWriter "github.com/flyteorg/flyte/flyteadmin/pkg/async/events/implementations" "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications" @@ -44,17 +42,6 @@ type AdminService struct { Metrics AdminMetrics } -// Intercepts all admin requests to handle panics during execution. -func (m *AdminService) interceptPanic(ctx context.Context, request proto.Message) { - err := recover() - if err == nil { - return - } - - m.Metrics.PanicCounter.Inc() - logger.Fatalf(ctx, "panic-ed for request: [%+v] with err: %v with Stack: %v", request, err, string(debug.Stack())) -} - const defaultRetries = 3 func NewAdminServer(ctx context.Context, pluginRegistry *plugins.Registry, configuration runtimeIfaces.Configuration, diff --git a/flyteadmin/pkg/rpc/adminservice/base_test.go b/flyteadmin/pkg/rpc/adminservice/base_test.go deleted file mode 100644 index 9b1cb626d5..0000000000 --- a/flyteadmin/pkg/rpc/adminservice/base_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package adminservice - -import ( - "context" - "testing" - - "github.com/golang/protobuf/proto" - "github.com/stretchr/testify/assert" - - "github.com/flyteorg/flyte/flytestdlib/logger" - "github.com/flyteorg/flyte/flytestdlib/promutils" -) - -func Test_interceptPanic(t *testing.T) { - m := AdminService{ - Metrics: InitMetrics(promutils.NewTestScope()), - } - - ctx := context.Background() - - // Mute logs to avoid .Fatal() (called in interceptPanic) causing the process to close - assert.NoError(t, logger.SetConfig(&logger.Config{Mute: true})) - - func() { - defer func() { - if err := recover(); err != nil { - assert.Fail(t, "Unexpected error", err) - } - }() - - a := func() { - defer m.interceptPanic(ctx, proto.Message(nil)) - - var x *int - *x = 10 - } - - a() - }() -} diff --git a/flyteadmin/pkg/rpc/adminservice/description_entity.go b/flyteadmin/pkg/rpc/adminservice/description_entity.go index 1d08234051..bc2d794aed 100644 --- a/flyteadmin/pkg/rpc/adminservice/description_entity.go +++ b/flyteadmin/pkg/rpc/adminservice/description_entity.go @@ -13,7 +13,6 @@ import ( ) func (m *AdminService) GetDescriptionEntity(ctx context.Context, request *admin.ObjectGetRequest) (*admin.DescriptionEntity, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -36,7 +35,6 @@ func (m *AdminService) GetDescriptionEntity(ctx context.Context, request *admin. } func (m *AdminService) ListDescriptionEntities(ctx context.Context, request *admin.DescriptionEntityListRequest) (*admin.DescriptionEntityList, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } diff --git a/flyteadmin/pkg/rpc/adminservice/execution.go b/flyteadmin/pkg/rpc/adminservice/execution.go index 919ed851a3..15caf5aa75 100644 --- a/flyteadmin/pkg/rpc/adminservice/execution.go +++ b/flyteadmin/pkg/rpc/adminservice/execution.go @@ -13,7 +13,6 @@ import ( func (m *AdminService) CreateExecution( ctx context.Context, request *admin.ExecutionCreateRequest) (*admin.ExecutionCreateResponse, error) { - defer m.interceptPanic(ctx, request) requestedAt := time.Now() if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") @@ -32,7 +31,6 @@ func (m *AdminService) CreateExecution( func (m *AdminService) RelaunchExecution( ctx context.Context, request *admin.ExecutionRelaunchRequest) (*admin.ExecutionCreateResponse, error) { - defer m.interceptPanic(ctx, request) requestedAt := time.Now() if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") @@ -51,7 +49,6 @@ func (m *AdminService) RelaunchExecution( func (m *AdminService) RecoverExecution( ctx context.Context, request *admin.ExecutionRecoverRequest) (*admin.ExecutionCreateResponse, error) { - defer m.interceptPanic(ctx, request) requestedAt := time.Now() if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") @@ -70,7 +67,6 @@ func (m *AdminService) RecoverExecution( func (m *AdminService) CreateWorkflowEvent( ctx context.Context, request *admin.WorkflowExecutionEventRequest) (*admin.WorkflowExecutionEventResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -89,7 +85,6 @@ func (m *AdminService) CreateWorkflowEvent( func (m *AdminService) GetExecution( ctx context.Context, request *admin.WorkflowExecutionGetRequest) (*admin.Execution, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -107,7 +102,6 @@ func (m *AdminService) GetExecution( func (m *AdminService) UpdateExecution( ctx context.Context, request *admin.ExecutionUpdateRequest) (*admin.ExecutionUpdateResponse, error) { - defer m.interceptPanic(ctx, request) requestedAt := time.Now() if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") @@ -126,7 +120,6 @@ func (m *AdminService) UpdateExecution( func (m *AdminService) GetExecutionData( ctx context.Context, request *admin.WorkflowExecutionGetDataRequest) (*admin.WorkflowExecutionGetDataResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -144,7 +137,6 @@ func (m *AdminService) GetExecutionData( func (m *AdminService) GetExecutionMetrics( ctx context.Context, request *admin.WorkflowExecutionGetMetricsRequest) (*admin.WorkflowExecutionGetMetricsResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -162,7 +154,6 @@ func (m *AdminService) GetExecutionMetrics( func (m *AdminService) ListExecutions( ctx context.Context, request *admin.ResourceListRequest) (*admin.ExecutionList, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -180,7 +171,6 @@ func (m *AdminService) ListExecutions( func (m *AdminService) TerminateExecution( ctx context.Context, request *admin.ExecutionTerminateRequest) (*admin.ExecutionTerminateResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } diff --git a/flyteadmin/pkg/rpc/adminservice/launch_plan.go b/flyteadmin/pkg/rpc/adminservice/launch_plan.go index ff3c2480e0..1586c3f542 100644 --- a/flyteadmin/pkg/rpc/adminservice/launch_plan.go +++ b/flyteadmin/pkg/rpc/adminservice/launch_plan.go @@ -14,7 +14,6 @@ import ( func (m *AdminService) CreateLaunchPlan( ctx context.Context, request *admin.LaunchPlanCreateRequest) (*admin.LaunchPlanCreateResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -31,7 +30,6 @@ func (m *AdminService) CreateLaunchPlan( } func (m *AdminService) GetLaunchPlan(ctx context.Context, request *admin.ObjectGetRequest) (*admin.LaunchPlan, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -55,7 +53,6 @@ func (m *AdminService) GetLaunchPlan(ctx context.Context, request *admin.ObjectG } func (m *AdminService) GetActiveLaunchPlan(ctx context.Context, request *admin.ActiveLaunchPlanRequest) (*admin.LaunchPlan, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -73,7 +70,6 @@ func (m *AdminService) GetActiveLaunchPlan(ctx context.Context, request *admin.A func (m *AdminService) UpdateLaunchPlan(ctx context.Context, request *admin.LaunchPlanUpdateRequest) ( *admin.LaunchPlanUpdateResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -97,7 +93,6 @@ func (m *AdminService) UpdateLaunchPlan(ctx context.Context, request *admin.Laun func (m *AdminService) ListLaunchPlans(ctx context.Context, request *admin.ResourceListRequest) ( *admin.LaunchPlanList, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Empty request. Please rephrase.") } @@ -116,7 +111,6 @@ func (m *AdminService) ListLaunchPlans(ctx context.Context, request *admin.Resou func (m *AdminService) ListActiveLaunchPlans(ctx context.Context, request *admin.ActiveLaunchPlanListRequest) ( *admin.LaunchPlanList, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Empty request. Please rephrase.") } @@ -135,7 +129,6 @@ func (m *AdminService) ListActiveLaunchPlans(ctx context.Context, request *admin func (m *AdminService) ListLaunchPlanIds(ctx context.Context, request *admin.NamedEntityIdentifierListRequest) ( *admin.NamedEntityIdentifierList, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Empty request. Please rephrase.") } diff --git a/flyteadmin/pkg/rpc/adminservice/metrics.go b/flyteadmin/pkg/rpc/adminservice/metrics.go index 65c6b741f3..f770665ef6 100644 --- a/flyteadmin/pkg/rpc/adminservice/metrics.go +++ b/flyteadmin/pkg/rpc/adminservice/metrics.go @@ -2,8 +2,6 @@ package adminservice import ( - "github.com/prometheus/client_golang/prometheus" - "github.com/flyteorg/flyte/flyteadmin/pkg/rpc/adminservice/util" "github.com/flyteorg/flyte/flytestdlib/promutils" ) @@ -115,8 +113,7 @@ type descriptionEntityEndpointMetrics struct { } type AdminMetrics struct { - Scope promutils.Scope - PanicCounter prometheus.Counter + Scope promutils.Scope executionEndpointMetrics executionEndpointMetrics launchPlanEndpointMetrics launchPlanEndpointMetrics @@ -137,8 +134,6 @@ type AdminMetrics struct { func InitMetrics(adminScope promutils.Scope) AdminMetrics { return AdminMetrics{ Scope: adminScope, - PanicCounter: adminScope.MustNewCounter("handler_panic", - "panics encountered while handling requests to the admin service"), executionEndpointMetrics: executionEndpointMetrics{ scope: adminScope, diff --git a/flyteadmin/pkg/rpc/adminservice/middleware/recovery_interceptor.go b/flyteadmin/pkg/rpc/adminservice/middleware/recovery_interceptor.go new file mode 100644 index 0000000000..a0a699a4f0 --- /dev/null +++ b/flyteadmin/pkg/rpc/adminservice/middleware/recovery_interceptor.go @@ -0,0 +1,61 @@ +package middleware + +import ( + "context" + "runtime/debug" + + "github.com/prometheus/client_golang/prometheus" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/flyteorg/flyte/flytestdlib/logger" + "github.com/flyteorg/flyte/flytestdlib/promutils" +) + +// RecoveryInterceptor is a struct for creating gRPC interceptors that handle panics in go +type RecoveryInterceptor struct { + panicCounter prometheus.Counter +} + +// NewRecoveryInterceptor creates a new RecoveryInterceptor with metrics under the provided scope +func NewRecoveryInterceptor(adminScope promutils.Scope) *RecoveryInterceptor { + panicCounter := adminScope.MustNewCounter("handler_panic", "panics encountered while handling gRPC requests") + return &RecoveryInterceptor{ + panicCounter: panicCounter, + } +} + +// UnaryServerInterceptor returns a new unary server interceptor for panic recovery. +func (ri *RecoveryInterceptor) UnaryServerInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ any, err error) { + + defer func() { + if r := recover(); r != nil { + ri.panicCounter.Inc() + logger.Errorf(ctx, "panic-ed for request: [%+v] to %s with err: %v with Stack: %v", req, info.FullMethod, r, string(debug.Stack())) + // Return INTERNAL to client with no info as to not leak implementation details + err = status.Errorf(codes.Internal, "") + } + }() + + return handler(ctx, req) + } +} + +// StreamServerInterceptor returns a new streaming server interceptor for panic recovery. +func (ri *RecoveryInterceptor) StreamServerInterceptor() grpc.StreamServerInterceptor { + return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { + + defer func() { + if r := recover(); r != nil { + ri.panicCounter.Inc() + logger.Errorf(stream.Context(), "panic-ed for stream to %s with err: %v with Stack: %v", info.FullMethod, r, string(debug.Stack())) + // Return INTERNAL to client with no info as to not leak implementation details + err = status.Errorf(codes.Internal, "") + } + }() + + return handler(srv, stream) + } +} diff --git a/flyteadmin/pkg/rpc/adminservice/middleware/recovery_interceptor_test.go b/flyteadmin/pkg/rpc/adminservice/middleware/recovery_interceptor_test.go new file mode 100644 index 0000000000..3928856067 --- /dev/null +++ b/flyteadmin/pkg/rpc/adminservice/middleware/recovery_interceptor_test.go @@ -0,0 +1,90 @@ +package middleware + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + + mockScope "github.com/flyteorg/flyte/flytestdlib/promutils" +) + +func TestRecoveryInterceptor(t *testing.T) { + ctx := context.Background() + testScope := mockScope.NewTestScope() + recoveryInterceptor := NewRecoveryInterceptor(testScope) + unaryInterceptor := recoveryInterceptor.UnaryServerInterceptor() + streamInterceptor := recoveryInterceptor.StreamServerInterceptor() + unaryInfo := &grpc.UnaryServerInfo{} + streamInfo := &grpc.StreamServerInfo{} + req := "test-request" + + t.Run("unary should recover from panic", func(t *testing.T) { + _, err := unaryInterceptor(ctx, req, unaryInfo, func(ctx context.Context, req any) (any, error) { + panic("synthetic") + }) + expectedErr := status.Errorf(codes.Internal, "") + require.Error(t, err) + require.Equal(t, expectedErr, err) + }) + + t.Run("stream should recover from panic", func(t *testing.T) { + stream := testStream{} + err := streamInterceptor(nil, &stream, streamInfo, func(srv any, stream grpc.ServerStream) error { + panic("synthetic") + }) + expectedErr := status.Errorf(codes.Internal, "") + require.Error(t, err) + require.Equal(t, expectedErr, err) + }) + + t.Run("unary should plumb response without panic", func(t *testing.T) { + mockedResponse := "test" + resp, err := unaryInterceptor(ctx, req, unaryInfo, func(ctx context.Context, req any) (any, error) { + return mockedResponse, nil + }) + require.NoError(t, err) + require.Equal(t, mockedResponse, resp) + }) + + t.Run("stream should plumb response without panic", func(t *testing.T) { + stream := testStream{} + handlerCalled := false + err := streamInterceptor(nil, &stream, streamInfo, func(srv any, stream grpc.ServerStream) error { + handlerCalled = true + return nil + }) + require.NoError(t, err) + require.True(t, handlerCalled) + }) +} + +// testStream is an implementation of grpc.ServerStream for testing. +type testStream struct { +} + +func (s *testStream) SendMsg(m interface{}) error { + return nil +} + +func (s *testStream) RecvMsg(m interface{}) error { + return nil +} + +func (s *testStream) SetHeader(metadata.MD) error { + return nil +} + +func (s *testStream) SendHeader(metadata.MD) error { + return nil +} + +func (s *testStream) SetTrailer(metadata.MD) {} + +func (s *testStream) Context() context.Context { + return context.Background() +} diff --git a/flyteadmin/pkg/rpc/adminservice/named_entity.go b/flyteadmin/pkg/rpc/adminservice/named_entity.go index d48a0485e2..4ef8f3ee0b 100644 --- a/flyteadmin/pkg/rpc/adminservice/named_entity.go +++ b/flyteadmin/pkg/rpc/adminservice/named_entity.go @@ -11,7 +11,6 @@ import ( ) func (m *AdminService) GetNamedEntity(ctx context.Context, request *admin.NamedEntityGetRequest) (*admin.NamedEntity, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -31,7 +30,6 @@ func (m *AdminService) GetNamedEntity(ctx context.Context, request *admin.NamedE func (m *AdminService) UpdateNamedEntity(ctx context.Context, request *admin.NamedEntityUpdateRequest) ( *admin.NamedEntityUpdateResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -50,7 +48,6 @@ func (m *AdminService) UpdateNamedEntity(ctx context.Context, request *admin.Nam func (m *AdminService) ListNamedEntities(ctx context.Context, request *admin.NamedEntityListRequest) ( *admin.NamedEntityList, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } diff --git a/flyteadmin/pkg/rpc/adminservice/node_execution.go b/flyteadmin/pkg/rpc/adminservice/node_execution.go index cf17e3ff70..1b187f3a35 100644 --- a/flyteadmin/pkg/rpc/adminservice/node_execution.go +++ b/flyteadmin/pkg/rpc/adminservice/node_execution.go @@ -14,7 +14,6 @@ import ( func (m *AdminService) CreateNodeEvent( ctx context.Context, request *admin.NodeExecutionEventRequest) (*admin.NodeExecutionEventResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -32,7 +31,6 @@ func (m *AdminService) CreateNodeEvent( func (m *AdminService) GetNodeExecution( ctx context.Context, request *admin.NodeExecutionGetRequest) (*admin.NodeExecution, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -49,7 +47,6 @@ func (m *AdminService) GetNodeExecution( } func (m *AdminService) GetDynamicNodeWorkflow(ctx context.Context, request *admin.GetDynamicNodeWorkflowRequest) (*admin.DynamicNodeWorkflowResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -68,7 +65,6 @@ func (m *AdminService) GetDynamicNodeWorkflow(ctx context.Context, request *admi func (m *AdminService) ListNodeExecutions( ctx context.Context, request *admin.NodeExecutionListRequest) (*admin.NodeExecutionList, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -86,7 +82,6 @@ func (m *AdminService) ListNodeExecutions( func (m *AdminService) ListNodeExecutionsForTask( ctx context.Context, request *admin.NodeExecutionForTaskListRequest) (*admin.NodeExecutionList, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -111,7 +106,6 @@ func (m *AdminService) ListNodeExecutionsForTask( func (m *AdminService) GetNodeExecutionData( ctx context.Context, request *admin.NodeExecutionGetDataRequest) (*admin.NodeExecutionGetDataResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } diff --git a/flyteadmin/pkg/rpc/adminservice/project.go b/flyteadmin/pkg/rpc/adminservice/project.go index 5e7352ad93..ab8d8e4375 100644 --- a/flyteadmin/pkg/rpc/adminservice/project.go +++ b/flyteadmin/pkg/rpc/adminservice/project.go @@ -12,7 +12,6 @@ import ( func (m *AdminService) RegisterProject(ctx context.Context, request *admin.ProjectRegisterRequest) ( *admin.ProjectRegisterResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -29,7 +28,6 @@ func (m *AdminService) RegisterProject(ctx context.Context, request *admin.Proje } func (m *AdminService) ListProjects(ctx context.Context, request *admin.ProjectListRequest) (*admin.Projects, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -48,7 +46,6 @@ func (m *AdminService) ListProjects(ctx context.Context, request *admin.ProjectL func (m *AdminService) UpdateProject(ctx context.Context, request *admin.Project) ( *admin.ProjectUpdateResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -65,7 +62,6 @@ func (m *AdminService) UpdateProject(ctx context.Context, request *admin.Project } func (m *AdminService) GetProject(ctx context.Context, request *admin.ProjectGetRequest) (*admin.Project, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -83,7 +79,6 @@ func (m *AdminService) GetProject(ctx context.Context, request *admin.ProjectGet } func (m *AdminService) GetDomains(ctx context.Context, request *admin.GetDomainRequest) (*admin.GetDomainsResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } diff --git a/flyteadmin/pkg/rpc/adminservice/task.go b/flyteadmin/pkg/rpc/adminservice/task.go index 8899480489..7db51ed2eb 100644 --- a/flyteadmin/pkg/rpc/adminservice/task.go +++ b/flyteadmin/pkg/rpc/adminservice/task.go @@ -15,7 +15,6 @@ import ( func (m *AdminService) CreateTask( ctx context.Context, request *admin.TaskCreateRequest) (*admin.TaskCreateResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -32,7 +31,6 @@ func (m *AdminService) CreateTask( } func (m *AdminService) GetTask(ctx context.Context, request *admin.ObjectGetRequest) (*admin.Task, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -56,7 +54,6 @@ func (m *AdminService) GetTask(ctx context.Context, request *admin.ObjectGetRequ func (m *AdminService) ListTaskIds( ctx context.Context, request *admin.NamedEntityIdentifierListRequest) (*admin.NamedEntityIdentifierList, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -74,7 +71,6 @@ func (m *AdminService) ListTaskIds( } func (m *AdminService) ListTasks(ctx context.Context, request *admin.ResourceListRequest) (*admin.TaskList, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } diff --git a/flyteadmin/pkg/rpc/adminservice/task_execution.go b/flyteadmin/pkg/rpc/adminservice/task_execution.go index 0561a1ba36..0638c02aa3 100644 --- a/flyteadmin/pkg/rpc/adminservice/task_execution.go +++ b/flyteadmin/pkg/rpc/adminservice/task_execution.go @@ -15,7 +15,6 @@ import ( func (m *AdminService) CreateTaskEvent( ctx context.Context, request *admin.TaskExecutionEventRequest) (*admin.TaskExecutionEventResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -34,7 +33,6 @@ func (m *AdminService) CreateTaskEvent( func (m *AdminService) GetTaskExecution( ctx context.Context, request *admin.TaskExecutionGetRequest) (*admin.TaskExecution, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -62,7 +60,6 @@ func (m *AdminService) GetTaskExecution( func (m *AdminService) ListTaskExecutions( ctx context.Context, request *admin.TaskExecutionListRequest) (*admin.TaskExecutionList, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Nil request") } @@ -84,7 +81,6 @@ func (m *AdminService) ListTaskExecutions( func (m *AdminService) GetTaskExecutionData( ctx context.Context, request *admin.TaskExecutionGetDataRequest) (*admin.TaskExecutionGetDataResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } diff --git a/flyteadmin/pkg/rpc/adminservice/version.go b/flyteadmin/pkg/rpc/adminservice/version.go index 7fb5861e50..3049a723aa 100644 --- a/flyteadmin/pkg/rpc/adminservice/version.go +++ b/flyteadmin/pkg/rpc/adminservice/version.go @@ -8,7 +8,6 @@ import ( func (m *AdminService) GetVersion(ctx context.Context, request *admin.GetVersionRequest) (*admin.GetVersionResponse, error) { - defer m.interceptPanic(ctx, request) response, err := m.VersionManager.GetVersion(ctx, request) if err != nil { return nil, err diff --git a/flyteadmin/pkg/rpc/adminservice/workflow.go b/flyteadmin/pkg/rpc/adminservice/workflow.go index 9fcf87c453..7f6ecc4c13 100644 --- a/flyteadmin/pkg/rpc/adminservice/workflow.go +++ b/flyteadmin/pkg/rpc/adminservice/workflow.go @@ -15,7 +15,6 @@ import ( func (m *AdminService) CreateWorkflow( ctx context.Context, request *admin.WorkflowCreateRequest) (*admin.WorkflowCreateResponse, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -32,7 +31,6 @@ func (m *AdminService) CreateWorkflow( } func (m *AdminService) GetWorkflow(ctx context.Context, request *admin.ObjectGetRequest) (*admin.Workflow, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -56,7 +54,6 @@ func (m *AdminService) GetWorkflow(ctx context.Context, request *admin.ObjectGet func (m *AdminService) ListWorkflowIds(ctx context.Context, request *admin.NamedEntityIdentifierListRequest) ( *admin.NamedEntityIdentifierList, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } @@ -75,7 +72,6 @@ func (m *AdminService) ListWorkflowIds(ctx context.Context, request *admin.Named } func (m *AdminService) ListWorkflows(ctx context.Context, request *admin.ResourceListRequest) (*admin.WorkflowList, error) { - defer m.interceptPanic(ctx, request) if request == nil { return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") } diff --git a/flyteadmin/pkg/server/service.go b/flyteadmin/pkg/server/service.go index ff80c343d3..bb09f9f615 100644 --- a/flyteadmin/pkg/server/service.go +++ b/flyteadmin/pkg/server/service.go @@ -12,6 +12,7 @@ import ( "github.com/gorilla/handlers" grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware" grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth" + grpcrecovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery" grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/pkg/errors" @@ -35,6 +36,7 @@ import ( "github.com/flyteorg/flyte/flyteadmin/pkg/config" "github.com/flyteorg/flyte/flyteadmin/pkg/rpc" "github.com/flyteorg/flyte/flyteadmin/pkg/rpc/adminservice" + "github.com/flyteorg/flyte/flyteadmin/pkg/rpc/adminservice/middleware" runtime2 "github.com/flyteorg/flyte/flyteadmin/pkg/runtime" runtimeIfaces "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" "github.com/flyteorg/flyte/flyteadmin/plugins" @@ -98,11 +100,18 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c otelgrpc.WithPropagators(propagation.TraceContext{}), ) + adminScope := scope.NewSubScope("admin") + recoveryInterceptor := middleware.NewRecoveryInterceptor(adminScope) + var chainedUnaryInterceptors grpc.UnaryServerInterceptor if cfg.Security.UseAuth { logger.Infof(ctx, "Creating gRPC server with authentication") middlewareInterceptors := plugins.Get[grpc.UnaryServerInterceptor](pluginRegistry, plugins.PluginIDUnaryServiceMiddleware) - chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(grpcprometheus.UnaryServerInterceptor, + chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer( + // recovery interceptor should always be first in order to handle any panics in the middleware or server + recoveryInterceptor.UnaryServerInterceptor(), + grpcrecovery.UnaryServerInterceptor(), + grpcprometheus.UnaryServerInterceptor, otelUnaryServerInterceptor, auth.GetAuthenticationCustomMetadataInterceptor(authCtx), grpcauth.UnaryServerInterceptor(auth.GetAuthenticationInterceptor(authCtx)), @@ -111,11 +120,23 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c ) } else { logger.Infof(ctx, "Creating gRPC server without authentication") - chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(grpcprometheus.UnaryServerInterceptor, otelUnaryServerInterceptor) + chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer( + // recovery interceptor should always be first in order to handle any panics in the middleware or server + recoveryInterceptor.UnaryServerInterceptor(), + grpcprometheus.UnaryServerInterceptor, + otelUnaryServerInterceptor, + ) } + chainedStreamInterceptors := grpcmiddleware.ChainStreamServer( + // recovery interceptor should always be first in order to handle any panics in the middleware or server + recoveryInterceptor.StreamServerInterceptor(), + grpcprometheus.StreamServerInterceptor, + ) + serverOpts := []grpc.ServerOption{ - grpc.StreamInterceptor(grpcprometheus.StreamServerInterceptor), + // recovery interceptor should always be first in order to handle any panics in the middleware or server + grpc.StreamInterceptor(chainedStreamInterceptors), grpc.UnaryInterceptor(chainedUnaryInterceptors), } if cfg.GrpcConfig.MaxMessageSizeBytes > 0 { @@ -131,7 +152,7 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c } configuration := runtime2.NewConfigurationProvider() - adminServer := adminservice.NewAdminServer(ctx, pluginRegistry, configuration, cfg.KubeConfig, cfg.Master, dataStorageClient, scope.NewSubScope("admin")) + adminServer := adminservice.NewAdminServer(ctx, pluginRegistry, configuration, cfg.KubeConfig, cfg.Master, dataStorageClient, adminScope) grpcService.RegisterAdminServiceServer(grpcServer, adminServer) if cfg.Security.UseAuth { grpcService.RegisterAuthMetadataServiceServer(grpcServer, authCtx.AuthMetadataService())