Skip to content

Commit

Permalink
move request scoped data into context
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Mar 27, 2018
1 parent 4e13262 commit 40918d5
Show file tree
Hide file tree
Showing 12 changed files with 307 additions and 293 deletions.
4 changes: 2 additions & 2 deletions codegen/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (f *Field) CallArgs() string {
var args []string

if f.GoMethodName == "" {
args = append(args, "ec.ctx")
args = append(args, "ctx")

if !f.Object.Root {
args = append(args, "obj")
Expand Down Expand Up @@ -134,7 +134,7 @@ func (f *Field) doWriteJson(val string, remainingMods []string, isPtr bool, dept
if !isPtr {
val = "&" + val
}
return fmt.Sprintf("return ec._%s(field.Selections, %s)", f.GQLType, val)
return fmt.Sprintf("return ec._%s(ctx, field.Selections, %s)", f.GQLType, val)
}
}

Expand Down
6 changes: 3 additions & 3 deletions codegen/templates/field.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
{{ $object := $field.Object }}

{{- if $object.Stream }}
func (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(field graphql.CollectedField) func() graphql.Marshaler {
func (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(ctx context.Context, field graphql.CollectedField) func() graphql.Marshaler {
{{- template "args.gotpl" $field.Args }}
results, err := ec.resolvers.{{ $object.GQLType }}_{{ $field.GQLName }}({{ $field.CallArgs }})
if err != nil {
Expand All @@ -20,14 +20,14 @@
}
}
{{ else }}
func (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(field graphql.CollectedField, {{if not $object.Root}}obj *{{$object.FullName}}{{end}}) graphql.Marshaler {
func (ec *executionContext) _{{$object.GQLType}}_{{$field.GQLName}}(ctx context.Context, field graphql.CollectedField, {{if not $object.Root}}obj *{{$object.FullName}}{{end}}) graphql.Marshaler {
{{- template "args.gotpl" $field.Args }}

{{- if $field.IsConcurrent }}
return graphql.Defer(func() (ret graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
userErr := ec.recover(r)
userErr := ec.Recover(r)
ec.Error(userErr)
ret = graphql.Null
}
Expand Down
49 changes: 13 additions & 36 deletions codegen/templates/generated.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,8 @@ import (
{{ end }}
)

func MakeExecutableSchema(resolvers Resolvers, opts ...ExecutableOption) graphql.ExecutableSchema {
ret := &executableSchema{resolvers: resolvers}
for _, opt := range opts {
opt(ret)
}
return ret
func MakeExecutableSchema(resolvers Resolvers) graphql.ExecutableSchema {
return &executableSchema{resolvers: resolvers}
}

type Resolvers interface {
Expand All @@ -24,28 +20,19 @@ type Resolvers interface {
{{- end }}
}

type ExecutableOption func(*executableSchema)

func WithErrorConverter(fn func(error) string) ExecutableOption {
return func(s *executableSchema) {
s.errorMessageFn = fn
}
}

type executableSchema struct {
resolvers Resolvers
errorMessageFn func(error) string
}

func (e *executableSchema) Schema() *schema.Schema {
return parsedSchema
}

func (e *executableSchema) Query(ctx context.Context, doc *query.Document, variables map[string]interface{}, op *query.Operation, recover graphql.RecoverFunc) *graphql.Response {
func (e *executableSchema) Query(ctx context.Context, op *query.Operation) *graphql.Response {
{{- if .QueryRoot }}
ec := e.makeExecutionContext(ctx, doc, variables, recover)
ec := executionContext{graphql.GetRequestContext(ctx), e.resolvers}

data := ec._{{.QueryRoot.GQLType}}(op.Selections)
data := ec._{{.QueryRoot.GQLType}}(ctx, op.Selections)
var buf bytes.Buffer
data.MarshalGQL(&buf)

Expand All @@ -58,11 +45,11 @@ func (e *executableSchema) Query(ctx context.Context, doc *query.Document, varia
{{- end }}
}

func (e *executableSchema) Mutation(ctx context.Context, doc *query.Document, variables map[string]interface{}, op *query.Operation, recover graphql.RecoverFunc) *graphql.Response {
func (e *executableSchema) Mutation(ctx context.Context, op *query.Operation) *graphql.Response {
{{- if .MutationRoot }}
ec := e.makeExecutionContext(ctx, doc, variables, recover)
ec := executionContext{graphql.GetRequestContext(ctx), e.resolvers}

data := ec._{{.MutationRoot.GQLType}}(op.Selections)
data := ec._{{.MutationRoot.GQLType}}(ctx, op.Selections)
var buf bytes.Buffer
data.MarshalGQL(&buf)

Expand All @@ -75,11 +62,11 @@ func (e *executableSchema) Mutation(ctx context.Context, doc *query.Document, va
{{- end }}
}

func (e *executableSchema) Subscription(ctx context.Context, doc *query.Document, variables map[string]interface{}, op *query.Operation, recover graphql.RecoverFunc) func() *graphql.Response {
func (e *executableSchema) Subscription(ctx context.Context, op *query.Operation) func() *graphql.Response {
{{- if .SubscriptionRoot }}
ec := e.makeExecutionContext(ctx, doc, variables, recover)
ec := executionContext{graphql.GetRequestContext(ctx), e.resolvers}

next := ec._{{.SubscriptionRoot.GQLType}}(op.Selections)
next := ec._{{.SubscriptionRoot.GQLType}}(ctx, op.Selections)
if ec.Errors != nil {
return graphql.OneShot(&graphql.Response{Data: []byte("null"), Errors: ec.Errors})
}
Expand All @@ -105,20 +92,10 @@ func (e *executableSchema) Subscription(ctx context.Context, doc *query.Document
{{- end }}
}

func (e *executableSchema) makeExecutionContext(ctx context.Context, doc *query.Document, variables map[string]interface{}, recover graphql.RecoverFunc) *executionContext {
errBuilder := errors.Builder{ErrorMessageFn: e.errorMessageFn}
return &executionContext{
Builder: errBuilder, resolvers: e.resolvers, variables: variables, doc: doc, ctx: ctx, recover: recover,
}
}

type executionContext struct {
errors.Builder
*graphql.RequestContext

resolvers Resolvers
variables map[string]interface{}
doc *query.Document
ctx context.Context
recover graphql.RecoverFunc
}

{{- range $object := .Objects }}
Expand Down
6 changes: 3 additions & 3 deletions codegen/templates/interface.gotpl
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
{{- $interface := . }}

func (ec *executionContext) _{{$interface.GQLType}}(sel []query.Selection, obj *{{$interface.FullName}}) graphql.Marshaler {
func (ec *executionContext) _{{$interface.GQLType}}(ctx context.Context, sel []query.Selection, obj *{{$interface.FullName}}) graphql.Marshaler {
switch obj := (*obj).(type) {
case nil:
return graphql.Null
{{- range $implementor := $interface.Implementors }}
{{- if $implementor.ValueReceiver }}
case {{$implementor.FullName}}:
return ec._{{$implementor.GQLType}}(sel, &obj)
return ec._{{$implementor.GQLType}}(ctx, sel, &obj)
{{- end}}
case *{{$implementor.FullName}}:
return ec._{{$implementor.GQLType}}(sel, obj)
return ec._{{$implementor.GQLType}}(ctx, sel, obj)
{{- end }}
default:
panic(fmt.Errorf("unexpected type %T", obj))
Expand Down
12 changes: 6 additions & 6 deletions codegen/templates/object.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ var {{ $object.GQLType|lcFirst}}Implementors = {{$object.Implementors}}

// nolint: gocyclo, errcheck, gas, goconst
{{- if .Stream }}
func (ec *executionContext) _{{$object.GQLType}}(sel []query.Selection) func() graphql.Marshaler {
fields := graphql.CollectFields(ec.doc, sel, {{$object.GQLType|lcFirst}}Implementors, ec.variables)
func (ec *executionContext) _{{$object.GQLType}}(ctx context.Context, sel []query.Selection) func() graphql.Marshaler {
fields := graphql.CollectFields(ec.Doc, sel, {{$object.GQLType|lcFirst}}Implementors, ec.Variables)

if len(fields) != 1 {
ec.Errorf("must subscribe to exactly one stream")
Expand All @@ -15,15 +15,15 @@ func (ec *executionContext) _{{$object.GQLType}}(sel []query.Selection) func() g
switch fields[0].Name {
{{- range $field := $object.Fields }}
case "{{$field.GQLName}}":
return ec._{{$object.GQLType}}_{{$field.GQLName}}(fields[0])
return ec._{{$object.GQLType}}_{{$field.GQLName}}(ctx, fields[0])
{{- end }}
default:
panic("unknown field " + strconv.Quote(fields[0].Name))
}
}
{{- else }}
func (ec *executionContext) _{{$object.GQLType}}(sel []query.Selection{{if not $object.Root}}, obj *{{$object.FullName}} {{end}}) graphql.Marshaler {
fields := graphql.CollectFields(ec.doc, sel, {{$object.GQLType|lcFirst}}Implementors, ec.variables)
func (ec *executionContext) _{{$object.GQLType}}(ctx context.Context, sel []query.Selection{{if not $object.Root}}, obj *{{$object.FullName}} {{end}}) graphql.Marshaler {
fields := graphql.CollectFields(ec.Doc, sel, {{$object.GQLType|lcFirst}}Implementors, ec.Variables)
out := graphql.NewOrderedMap(len(fields))
for i, field := range fields {
out.Keys[i] = field.Alias
Expand All @@ -33,7 +33,7 @@ func (ec *executionContext) _{{$object.GQLType}}(sel []query.Selection{{if not $
out.Values[i] = graphql.MarshalString({{$object.GQLType|quote}})
{{- range $field := $object.Fields }}
case "{{$field.GQLName}}":
out.Values[i] = ec._{{$object.GQLType}}_{{$field.GQLName}}(field{{if not $object.Root}}, obj{{end}})
out.Values[i] = ec._{{$object.GQLType}}_{{$field.GQLName}}(ctx, field{{if not $object.Root}}, obj{{end}})
{{- end }}
default:
panic("unknown field " + strconv.Quote(field.Name))
Expand Down
33 changes: 33 additions & 0 deletions graphql/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package graphql

import (
"context"

"github.com/vektah/gqlgen/neelance/errors"
"github.com/vektah/gqlgen/neelance/query"
)

type RequestContext struct {
errors.Builder

Variables map[string]interface{}
Doc *query.Document
Recover RecoverFunc
}

type key string

const rcKey key = "request_context"

func GetRequestContext(ctx context.Context) *RequestContext {
val := ctx.Value(rcKey)
if val == nil {
return nil
}

return val.(*RequestContext)
}

func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context {
return context.WithValue(ctx, rcKey, rc)
}
6 changes: 3 additions & 3 deletions graphql/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import (
type ExecutableSchema interface {
Schema() *schema.Schema

Query(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation, recover RecoverFunc) *Response
Mutation(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation, recover RecoverFunc) *Response
Subscription(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation, recover RecoverFunc) func() *Response
Query(ctx context.Context, op *query.Operation) *Response
Mutation(ctx context.Context, op *query.Operation) *Response
Subscription(ctx context.Context, op *query.Operation) func() *Response
}

func CollectFields(doc *query.Document, selSet []query.Selection, satisfies []string, variables map[string]interface{}) []CollectedField {
Expand Down
24 changes: 20 additions & 4 deletions handler/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ type params struct {
}

type Config struct {
upgrader websocket.Upgrader
recover graphql.RecoverFunc
upgrader websocket.Upgrader
recover graphql.RecoverFunc
formatError func(error) string
}

type Option func(cfg *Config)
Expand All @@ -39,6 +40,12 @@ func RecoverFunc(recover graphql.RecoverFunc) Option {
}
}

func FormatErrorFunc(f func(error) string) Option {
return func(cfg *Config) {
cfg.formatError = f
}
}

func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc {
cfg := Config{
recover: graphql.DefaultRecoverFunc,
Expand Down Expand Up @@ -96,15 +103,24 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc
return
}

ctx := graphql.WithRequestContext(r.Context(), &graphql.RequestContext{
Doc: doc,
Variables: reqParams.Variables,
Recover: cfg.recover,
Builder: errors.Builder{
ErrorMessageFn: cfg.formatError,
},
})

switch op.Type {
case query.Query:
b, err := json.Marshal(exec.Query(r.Context(), doc, reqParams.Variables, op, cfg.recover))
b, err := json.Marshal(exec.Query(ctx, op))
if err != nil {
panic(err)
}
w.Write(b)
case query.Mutation:
b, err := json.Marshal(exec.Mutation(r.Context(), doc, reqParams.Variables, op, cfg.recover))
b, err := json.Marshal(exec.Mutation(ctx, op))
if err != nil {
panic(err)
}
Expand Down
6 changes: 3 additions & 3 deletions handler/stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@ func (e *executableSchemaStub) Schema() *schema.Schema {
`)
}

func (e *executableSchemaStub) Query(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation, recover graphql.RecoverFunc) *graphql.Response {
func (e *executableSchemaStub) Query(ctx context.Context, op *query.Operation) *graphql.Response {
return &graphql.Response{Data: []byte(`{"name":"test"}`)}
}

func (e *executableSchemaStub) Mutation(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation, recover graphql.RecoverFunc) *graphql.Response {
func (e *executableSchemaStub) Mutation(ctx context.Context, op *query.Operation) *graphql.Response {
return &graphql.Response{
Errors: []*errors.QueryError{{Message: "mutations are not supported"}},
}
}

func (e *executableSchemaStub) Subscription(ctx context.Context, document *query.Document, variables map[string]interface{}, op *query.Operation, recover graphql.RecoverFunc) func() *graphql.Response {
func (e *executableSchemaStub) Subscription(ctx context.Context, op *query.Operation) func() *graphql.Response {
return func() *graphql.Response {
time.Sleep(20 * time.Millisecond)
select {
Expand Down
30 changes: 20 additions & 10 deletions handler/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,13 @@ type operationMessage struct {
}

type wsConnection struct {
ctx context.Context
conn *websocket.Conn
exec graphql.ExecutableSchema
active map[string]context.CancelFunc
mu sync.Mutex
recover graphql.RecoverFunc
ctx context.Context
conn *websocket.Conn
exec graphql.ExecutableSchema
active map[string]context.CancelFunc
mu sync.Mutex
recover graphql.RecoverFunc
formatError func(err error) string
}

func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, upgrader websocket.Upgrader, recover graphql.RecoverFunc) {
Expand Down Expand Up @@ -155,20 +156,29 @@ func (c *wsConnection) subscribe(message *operationMessage) bool {
return true
}

ctx := graphql.WithRequestContext(c.ctx, &graphql.RequestContext{
Doc: doc,
Variables: reqParams.Variables,
Recover: c.recover,
Builder: errors.Builder{
ErrorMessageFn: c.formatError,
},
})

if op.Type != query.Subscription {
var result *graphql.Response
if op.Type == query.Query {
result = c.exec.Query(c.ctx, doc, reqParams.Variables, op, c.recover)
result = c.exec.Query(ctx, op)
} else {
result = c.exec.Mutation(c.ctx, doc, reqParams.Variables, op, c.recover)
result = c.exec.Mutation(ctx, op)
}

c.sendData(message.ID, result)
c.write(&operationMessage{ID: message.ID, Type: completeMsg})
return true
}

ctx, cancel := context.WithCancel(c.ctx)
ctx, cancel := context.WithCancel(ctx)
c.mu.Lock()
c.active[message.ID] = cancel
c.mu.Unlock()
Expand All @@ -179,7 +189,7 @@ func (c *wsConnection) subscribe(message *operationMessage) bool {
c.sendError(message.ID, &errors.QueryError{Message: userErr.Error()})
}
}()
next := c.exec.Subscription(ctx, doc, reqParams.Variables, op, c.recover)
next := c.exec.Subscription(ctx, op)
for result := next(); result != nil; result = next() {
fmt.Println(result)
c.sendData(message.ID, result)
Expand Down
Loading

0 comments on commit 40918d5

Please sign in to comment.