From cc4e0ba28375e0ec39c586485fa138c28e5bfcba Mon Sep 17 00:00:00 2001 From: Chris Cerk Date: Tue, 29 Aug 2023 07:11:27 -0400 Subject: [PATCH] ensure HasOperationContext checks for nil (#2776) --- graphql/context_operation.go | 4 ++-- graphql/context_operation_test.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/graphql/context_operation.go b/graphql/context_operation.go index 77a42b84b7..3e6a221b0b 100644 --- a/graphql/context_operation.go +++ b/graphql/context_operation.go @@ -73,8 +73,8 @@ func WithOperationContext(ctx context.Context, rc *OperationContext) context.Con // // Some errors can happen outside of an operation, eg json unmarshal errors. func HasOperationContext(ctx context.Context) bool { - _, ok := ctx.Value(operationCtx).(*OperationContext) - return ok + val, ok := ctx.Value(operationCtx).(*OperationContext) + return ok && val != nil } // This is just a convenient wrapper method for CollectFields diff --git a/graphql/context_operation_test.go b/graphql/context_operation_test.go index 4ce374601f..cd4e61e716 100644 --- a/graphql/context_operation_test.go +++ b/graphql/context_operation_test.go @@ -3,11 +3,33 @@ package graphql import ( "context" "testing" + "time" "github.com/stretchr/testify/require" "github.com/vektah/gqlparser/v2/ast" ) +// implement context.Context interface +type testGraphRequestContext struct { + opContext *OperationContext +} + +func (t *testGraphRequestContext) Deadline() (deadline time.Time, ok bool) { + return time.Time{}, false +} + +func (t *testGraphRequestContext) Done() <-chan struct{} { + return nil +} + +func (t *testGraphRequestContext) Err() error { + return nil +} + +func (t *testGraphRequestContext) Value(key interface{}) interface{} { + return t.opContext +} + func TestGetOperationContext(t *testing.T) { rc := &OperationContext{} @@ -26,6 +48,15 @@ func TestGetOperationContext(t *testing.T) { GetOperationContext(ctx) }) }) + + t.Run("with nil operation context", func(t *testing.T) { + ctx := &testGraphRequestContext{opContext: nil} + + require.False(t, HasOperationContext(ctx)) + require.Panics(t, func() { + GetOperationContext(ctx) + }) + }) } func TestCollectAllFields(t *testing.T) {