Skip to content

Commit

Permalink
split context.go into 3 files
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Oct 31, 2019
1 parent 72c47c9 commit 9d1d77e
Show file tree
Hide file tree
Showing 12 changed files with 320 additions and 297 deletions.
91 changes: 2 additions & 89 deletions graphql/context.go → graphql/context_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,6 @@ import (
"github.com/vektah/gqlparser/gqlerror"
)

type Resolver func(ctx context.Context) (res interface{}, err error)
type ResultMiddleware func(ctx context.Context, next ResponseHandler) *Response
type OperationMiddleware func(ctx context.Context, next OperationHandler, writer Writer)
type FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)
type ComplexityLimitFunc func(ctx context.Context) int

type RequestContext struct {
RawQuery string
Variables map[string]interface{}
Expand All @@ -24,17 +18,15 @@ type RequestContext struct {
OperationComplexity int
DisableIntrospection bool

// ErrorPresenter will be used to generate the error
// message from errors given to Error().
ErrorPresenter ErrorPresenterFunc
Recover RecoverFunc
ResolverMiddleware FieldMiddleware
DirectiveMiddleware FieldMiddleware
RequestMiddleware OperationInterceptor

Stats Stats
}

const requestCtx key = "request_context"

func (rc *RequestContext) Validate(ctx context.Context) error {
if rc.Doc == nil {
return errors.New("field 'Doc' must be required")
Expand All @@ -54,31 +46,13 @@ func (rc *RequestContext) Validate(ctx context.Context) error {
if rc.Recover == nil {
rc.Recover = DefaultRecover
}
if rc.ErrorPresenter == nil {
rc.ErrorPresenter = DefaultErrorPresenter
}
if rc.ComplexityLimit < 0 {
return errors.New("field 'ComplexityLimit' value must be 0 or more")
}

return nil
}

func DefaultResolverMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
return next(ctx)
}

func DefaultDirectiveMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
return next(ctx)
}

type key string

const (
requestCtx key = "request_context"
resolverCtx key = "resolver_context"
)

func GetRequestContext(ctx context.Context) *RequestContext {
if val, ok := ctx.Value(requestCtx).(*RequestContext); ok {
return val
Expand All @@ -90,53 +64,6 @@ func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context
return context.WithValue(ctx, requestCtx, rc)
}

type ResolverContext struct {
Parent *ResolverContext
// The name of the type this field belongs to
Object string
// These are the args after processing, they can be mutated in middleware to change what the resolver will get.
Args map[string]interface{}
// The raw field
Field CollectedField
// The index of array in path.
Index *int
// The result object of resolver
Result interface{}
// IsMethod indicates if the resolver is a method
IsMethod bool
}

func (r *ResolverContext) Path() []interface{} {
var path []interface{}
for it := r; it != nil; it = it.Parent {
if it.Index != nil {
path = append(path, *it.Index)
} else if it.Field.Field != nil {
path = append(path, it.Field.Alias)
}
}

// because we are walking up the chain, all the elements are backwards, do an inplace flip.
for i := len(path)/2 - 1; i >= 0; i-- {
opp := len(path) - 1 - i
path[i], path[opp] = path[opp], path[i]
}

return path
}

func GetResolverContext(ctx context.Context) *ResolverContext {
if val, ok := ctx.Value(resolverCtx).(*ResolverContext); ok {
return val
}
return nil
}

func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context {
rc.Parent = GetResolverContext(ctx)
return context.WithValue(ctx, resolverCtx, rc)
}

// This is just a convenient wrapper method for CollectFields
func CollectFieldsCtx(ctx context.Context, satisfies []string) []CollectedField {
resctx := GetResolverContext(ctx)
Expand Down Expand Up @@ -173,20 +100,6 @@ func (c *RequestContext) Error(ctx context.Context, err error) {
AddError(ctx, err)
}

func equalPath(a []interface{}, b []interface{}) bool {
if len(a) != len(b) {
return false
}

for i := 0; i < len(a); i++ {
if a[i] != b[i] {
return false
}
}

return true
}

var _ RequestContextMutator = ComplexityLimitFunc(nil)

func (c ComplexityLimitFunc) MutateRequestContext(ctx context.Context, rc *RequestContext) *gqlerror.Error {
Expand Down
68 changes: 68 additions & 0 deletions graphql/context_request_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package graphql

import (
"context"
"testing"

"github.com/stretchr/testify/require"
"github.com/vektah/gqlparser/ast"
)

func TestGetRequestContext(t *testing.T) {
require.Nil(t, GetRequestContext(context.Background()))

rc := &RequestContext{}
require.Equal(t, rc, GetRequestContext(WithRequestContext(context.Background(), rc)))
}

func TestCollectAllFields(t *testing.T) {
t.Run("collect fields", func(t *testing.T) {
ctx := testContext(ast.SelectionSet{
&ast.Field{
Name: "field",
},
})
s := CollectAllFields(ctx)
require.Equal(t, []string{"field"}, s)
})

t.Run("unique field names", func(t *testing.T) {
ctx := testContext(ast.SelectionSet{
&ast.Field{
Name: "field",
},
&ast.Field{
Name: "field",
Alias: "field alias",
},
})
s := CollectAllFields(ctx)
require.Equal(t, []string{"field"}, s)
})

t.Run("collect fragments", func(t *testing.T) {
ctx := testContext(ast.SelectionSet{
&ast.Field{
Name: "fieldA",
},
&ast.InlineFragment{
TypeCondition: "ExampleTypeA",
SelectionSet: ast.SelectionSet{
&ast.Field{
Name: "fieldA",
},
},
},
&ast.InlineFragment{
TypeCondition: "ExampleTypeB",
SelectionSet: ast.SelectionSet{
&ast.Field{
Name: "fieldB",
},
},
},
})
s := CollectAllFields(ctx)
require.Equal(t, []string{"fieldA", "fieldB"}, s)
})
}
78 changes: 78 additions & 0 deletions graphql/context_resolver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package graphql

import (
"context"
)

func DefaultResolverMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
return next(ctx)
}

func DefaultDirectiveMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
return next(ctx)
}

type key string

const resolverCtx key = "resolver_context"

type ResolverContext struct {
Parent *ResolverContext
// The name of the type this field belongs to
Object string
// These are the args after processing, they can be mutated in middleware to change what the resolver will get.
Args map[string]interface{}
// The raw field
Field CollectedField
// The index of array in path.
Index *int
// The result object of resolver
Result interface{}
// IsMethod indicates if the resolver is a method
IsMethod bool
}

func (r *ResolverContext) Path() []interface{} {
var path []interface{}
for it := r; it != nil; it = it.Parent {
if it.Index != nil {
path = append(path, *it.Index)
} else if it.Field.Field != nil {
path = append(path, it.Field.Alias)
}
}

// because we are walking up the chain, all the elements are backwards, do an inplace flip.
for i := len(path)/2 - 1; i >= 0; i-- {
opp := len(path) - 1 - i
path[i], path[opp] = path[opp], path[i]
}

return path
}

func GetResolverContext(ctx context.Context) *ResolverContext {
if val, ok := ctx.Value(resolverCtx).(*ResolverContext); ok {
return val
}
return nil
}

func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context {
rc.Parent = GetResolverContext(ctx)
return context.WithValue(ctx, resolverCtx, rc)
}

func equalPath(a []interface{}, b []interface{}) bool {
if len(a) != len(b) {
return false
}

for i := 0; i < len(a); i++ {
if a[i] != b[i] {
return false
}
}

return true
}
33 changes: 33 additions & 0 deletions graphql/context_resolver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package graphql

import (
"context"
"testing"

"github.com/stretchr/testify/require"
"github.com/vektah/gqlparser/ast"
)

func TestGetResolverContext(t *testing.T) {
require.Nil(t, GetResolverContext(context.Background()))

rc := &ResolverContext{}
require.Equal(t, rc, GetResolverContext(WithResolverContext(context.Background(), rc)))
}

func testContext(sel ast.SelectionSet) context.Context {

ctx := context.Background()

rqCtx := &RequestContext{}
ctx = WithRequestContext(ctx, rqCtx)

root := &ResolverContext{
Field: CollectedField{
Selections: sel,
},
}
ctx = WithResolverContext(ctx, root)

return ctx
}
21 changes: 16 additions & 5 deletions graphql/context_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,28 @@ import (
)

type responseContext struct {
errorPresenter ErrorPresenterFunc
recover RecoverFunc

errors gqlerror.List
errorsMu sync.Mutex

extensions map[string]interface{}
extensionsMu sync.Mutex
}

var resultCtx key = "result_context"
const resultCtx key = "result_context"

func getResponseContext(ctx context.Context) *responseContext {
val, _ := ctx.Value(resultCtx).(*responseContext)
return val
}

func WithResponseContext(ctx context.Context) context.Context {
return context.WithValue(ctx, resultCtx, &responseContext{})
func WithResponseContext(ctx context.Context, presenterFunc ErrorPresenterFunc, recoverFunc RecoverFunc) context.Context {
return context.WithValue(ctx, resultCtx, &responseContext{
errorPresenter: presenterFunc,
recover: recoverFunc,
})
}

// AddErrorf writes a formatted error to the client, first passing it through the error presenter.
Expand All @@ -34,7 +40,7 @@ func AddErrorf(ctx context.Context, format string, args ...interface{}) {
c.errorsMu.Lock()
defer c.errorsMu.Unlock()

c.errors = append(c.errors, GetRequestContext(ctx).ErrorPresenter(ctx, fmt.Errorf(format, args...)))
c.errors = append(c.errors, c.errorPresenter(ctx, fmt.Errorf(format, args...)))
}

// AddError sends an error to the client, first passing it through the error presenter.
Expand All @@ -44,7 +50,12 @@ func AddError(ctx context.Context, err error) {
c.errorsMu.Lock()
defer c.errorsMu.Unlock()

c.errors = append(c.errors, GetRequestContext(ctx).ErrorPresenter(ctx, err))
c.errors = append(c.errors, c.errorPresenter(ctx, err))
}

func Recover(ctx context.Context, err interface{}) (userMessage error) {
c := getResponseContext(ctx)
return c.recover(ctx, err)
}

// HasFieldError returns true if the given field has already errored
Expand Down
Loading

0 comments on commit 9d1d77e

Please sign in to comment.