Skip to content

Commit

Permalink
add fields to resolver context
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Mar 27, 2018
1 parent 40918d5 commit e700774
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 11 deletions.
2 changes: 1 addition & 1 deletion codegen/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (f *Field) CallArgs() string {
var args []string

if f.GoMethodName == "" {
args = append(args, "ctx")
args = append(args, "rctx")

if !f.Object.Root {
args = append(args, "obj")
Expand Down
2 changes: 2 additions & 0 deletions codegen/templates/field.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
{{- if $object.Stream }}
func (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(ctx context.Context, field graphql.CollectedField) func() graphql.Marshaler {
{{- template "args.gotpl" $field.Args }}
rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{Field: field})
results, err := ec.resolvers.{{ $object.GQLType }}_{{ $field.GQLName }}({{ $field.CallArgs }})
if err != nil {
ec.Error(err)
Expand Down Expand Up @@ -47,6 +48,7 @@
}
{{- end }}
{{- else }}
rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{Field: field})
res, err := ec.resolvers.{{ $object.GQLType }}_{{ $field.GQLName }}({{ $field.CallArgs }})
if err != nil {
ec.Error(err)
Expand Down
26 changes: 23 additions & 3 deletions graphql/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@ type RequestContext struct {

type key string

const rcKey key = "request_context"
const (
request key = "request_context"
resolver key = "resolver_context"
)

func GetRequestContext(ctx context.Context) *RequestContext {
val := ctx.Value(rcKey)
val := ctx.Value(request)
if val == nil {
return nil
}
Expand All @@ -29,5 +32,22 @@ func GetRequestContext(ctx context.Context) *RequestContext {
}

func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context {
return context.WithValue(ctx, rcKey, rc)
return context.WithValue(ctx, request, rc)
}

type ResolverContext struct {
Field CollectedField
}

func GetResolverContext(ctx context.Context) *ResolverContext {
val := ctx.Value(resolver)
if val == nil {
return nil
}

return val.(*ResolverContext)
}

func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context {
return context.WithValue(ctx, request, rc)
}
21 changes: 14 additions & 7 deletions test/generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ func (ec *executionContext) _OuterObject_inner(ctx context.Context, field graphq
ret = graphql.Null
}
}()
res, err := ec.resolvers.OuterObject_inner(ctx, obj)
rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{Field: field})
res, err := ec.resolvers.OuterObject_inner(rctx, obj)
if err != nil {
ec.Error(err)
return graphql.Null
Expand Down Expand Up @@ -272,7 +273,8 @@ func (ec *executionContext) _Query_nestedInputs(ctx context.Context, field graph
ret = graphql.Null
}
}()
res, err := ec.resolvers.Query_nestedInputs(ctx, arg0)
rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{Field: field})
res, err := ec.resolvers.Query_nestedInputs(rctx, arg0)
if err != nil {
ec.Error(err)
return graphql.Null
Expand All @@ -293,7 +295,8 @@ func (ec *executionContext) _Query_nestedOutputs(ctx context.Context, field grap
ret = graphql.Null
}
}()
res, err := ec.resolvers.Query_nestedOutputs(ctx)
rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{Field: field})
res, err := ec.resolvers.Query_nestedOutputs(rctx)
if err != nil {
ec.Error(err)
return graphql.Null
Expand Down Expand Up @@ -321,7 +324,8 @@ func (ec *executionContext) _Query_shapes(ctx context.Context, field graphql.Col
ret = graphql.Null
}
}()
res, err := ec.resolvers.Query_shapes(ctx)
rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{Field: field})
res, err := ec.resolvers.Query_shapes(rctx)
if err != nil {
ec.Error(err)
return graphql.Null
Expand Down Expand Up @@ -357,7 +361,8 @@ func (ec *executionContext) _Query_recursive(ctx context.Context, field graphql.
ret = graphql.Null
}
}()
res, err := ec.resolvers.Query_recursive(ctx, arg0)
rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{Field: field})
res, err := ec.resolvers.Query_recursive(rctx, arg0)
if err != nil {
ec.Error(err)
return graphql.Null
Expand Down Expand Up @@ -392,7 +397,8 @@ func (ec *executionContext) _Query_mapInput(ctx context.Context, field graphql.C
ret = graphql.Null
}
}()
res, err := ec.resolvers.Query_mapInput(ctx, arg0)
rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{Field: field})
res, err := ec.resolvers.Query_mapInput(rctx, arg0)
if err != nil {
ec.Error(err)
return graphql.Null
Expand All @@ -413,7 +419,8 @@ func (ec *executionContext) _Query_collision(ctx context.Context, field graphql.
ret = graphql.Null
}
}()
res, err := ec.resolvers.Query_collision(ctx)
rctx := graphql.WithResolverContext(ctx, &graphql.ResolverContext{Field: field})
res, err := ec.resolvers.Query_collision(rctx)
if err != nil {
ec.Error(err)
return graphql.Null
Expand Down

0 comments on commit e700774

Please sign in to comment.