diff --git a/codegen/object.go b/codegen/object.go index c037344fe65..656af297a07 100644 --- a/codegen/object.go +++ b/codegen/object.go @@ -33,16 +33,17 @@ type Object struct { type Field struct { *Type - Description string // Description of a field - GQLName string // The name of the field in graphql - GoFieldType GoFieldType // The field type in go, if any - GoReceiverName string // The name of method & var receiver in go, if any - GoFieldName string // The name of the method or var in go, if any - Args []FieldArgument // A list of arguments to be passed to this field - ForceResolver bool // Should be emit Resolver method - NoErr bool // If this is bound to a go method, does that method have an error as the second argument - Object *Object // A link back to the parent object - Default interface{} // The default value + Description string // Description of a field + GQLName string // The name of the field in graphql + GoFieldType GoFieldType // The field type in go, if any + GoReceiverName string // The name of method & var receiver in go, if any + GoFieldName string // The name of the method or var in go, if any + Args []FieldArgument // A list of arguments to be passed to this field + ForceResolver bool // Should be emit Resolver method + MethodHasContext bool // If this is bound to a go method, does the method also take a context + NoErr bool // If this is bound to a go method, does that method have an error as the second argument + Object *Object // A link back to the parent object + Default interface{} // The default value } type FieldArgument struct { @@ -103,7 +104,10 @@ func (f *Field) IsVariable() bool { } func (f *Field) IsConcurrent() bool { - return f.IsResolver() && !f.Object.DisableConcurrency + if f.Object.DisableConcurrency { + return false + } + return f.MethodHasContext || f.IsResolver() } func (f *Field) GoNameExported() string { @@ -209,6 +213,10 @@ func (f *Field) CallArgs() string { if !f.Object.Root { args = append(args, "obj") } + } else { + if f.MethodHasContext { + args = append(args, "ctx") + } } for _, arg := range f.Args { diff --git a/codegen/testserver/generated.go b/codegen/testserver/generated.go index 8b7c196cd7f..8f9c7966b2a 100644 --- a/codegen/testserver/generated.go +++ b/codegen/testserver/generated.go @@ -34,6 +34,7 @@ type Config struct { type ResolverRoot interface { ForcedResolver() ForcedResolverResolver + ModelMethods() ModelMethodsResolver Query() QueryResolver Subscription() SubscriptionResolver User() UserResolver @@ -76,6 +77,12 @@ type ComplexityRoot struct { Id func(childComplexity int) int } + ModelMethods struct { + ResolverField func(childComplexity int) int + NoContext func(childComplexity int) int + WithContext func(childComplexity int) int + } + OuterObject struct { Inner func(childComplexity int) int } @@ -90,6 +97,7 @@ type ComplexityRoot struct { Keywords func(childComplexity int, input *Keywords) int Shapes func(childComplexity int) int ErrorBubble func(childComplexity int) int + ModelMethods func(childComplexity int) int Valid func(childComplexity int) int User func(childComplexity int, id int) int NullableArg func(childComplexity int, arg *int) int @@ -116,6 +124,9 @@ type ComplexityRoot struct { type ForcedResolverResolver interface { Field(ctx context.Context, obj *ForcedResolver) (*Circle, error) } +type ModelMethodsResolver interface { + ResolverField(ctx context.Context, obj *ModelMethods) (bool, error) +} type QueryResolver interface { InvalidIdentifier(ctx context.Context) (*invalid_packagename.InvalidIdentifier, error) Collision(ctx context.Context) (*introspection1.It, error) @@ -126,6 +137,7 @@ type QueryResolver interface { Keywords(ctx context.Context, input *Keywords) (bool, error) Shapes(ctx context.Context) ([]*Shape, error) ErrorBubble(ctx context.Context) (*Error, error) + ModelMethods(ctx context.Context) (*ModelMethods, error) Valid(ctx context.Context) (string, error) User(ctx context.Context, id int) (User, error) NullableArg(ctx context.Context, arg *int) (*string, error) @@ -648,6 +660,27 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.It.Id(childComplexity), true + case "ModelMethods.resolverField": + if e.complexity.ModelMethods.ResolverField == nil { + break + } + + return e.complexity.ModelMethods.ResolverField(childComplexity), true + + case "ModelMethods.noContext": + if e.complexity.ModelMethods.NoContext == nil { + break + } + + return e.complexity.ModelMethods.NoContext(childComplexity), true + + case "ModelMethods.withContext": + if e.complexity.ModelMethods.WithContext == nil { + break + } + + return e.complexity.ModelMethods.WithContext(childComplexity), true + case "OuterObject.inner": if e.complexity.OuterObject.Inner == nil { break @@ -738,6 +771,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.ErrorBubble(childComplexity), true + case "Query.modelMethods": + if e.complexity.Query.ModelMethods == nil { + break + } + + return e.complexity.Query.ModelMethods(childComplexity), true + case "Query.valid": if e.complexity.Query.Valid == nil { break @@ -1432,6 +1472,136 @@ func (ec *executionContext) _It_id(ctx context.Context, field graphql.CollectedF return graphql.MarshalID(res) } +var modelMethodsImplementors = []string{"ModelMethods"} + +// nolint: gocyclo, errcheck, gas, goconst +func (ec *executionContext) _ModelMethods(ctx context.Context, sel ast.SelectionSet, obj *ModelMethods) graphql.Marshaler { + fields := graphql.CollectFields(ctx, sel, modelMethodsImplementors) + + var wg sync.WaitGroup + out := graphql.NewOrderedMap(len(fields)) + invalid := false + for i, field := range fields { + out.Keys[i] = field.Alias + + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("ModelMethods") + case "resolverField": + wg.Add(1) + go func(i int, field graphql.CollectedField) { + out.Values[i] = ec._ModelMethods_resolverField(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalid = true + } + wg.Done() + }(i, field) + case "noContext": + out.Values[i] = ec._ModelMethods_noContext(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalid = true + } + case "withContext": + wg.Add(1) + go func(i int, field graphql.CollectedField) { + out.Values[i] = ec._ModelMethods_withContext(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalid = true + } + wg.Done() + }(i, field) + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + wg.Wait() + if invalid { + return graphql.Null + } + return out +} + +// nolint: vetshadow +func (ec *executionContext) _ModelMethods_resolverField(ctx context.Context, field graphql.CollectedField, obj *ModelMethods) graphql.Marshaler { + ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func() { ec.Tracer.EndFieldExecution(ctx) }() + rctx := &graphql.ResolverContext{ + Object: "ModelMethods", + Args: nil, + Field: field, + } + ctx = graphql.WithResolverContext(ctx, rctx) + ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) + resTmp := ec.FieldMiddleware(ctx, obj, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.ModelMethods().ResolverField(rctx, obj) + }) + if resTmp == nil { + if !ec.HasError(rctx) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(bool) + rctx.Result = res + ctx = ec.Tracer.StartFieldChildExecution(ctx) + return graphql.MarshalBoolean(res) +} + +// nolint: vetshadow +func (ec *executionContext) _ModelMethods_noContext(ctx context.Context, field graphql.CollectedField, obj *ModelMethods) graphql.Marshaler { + ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func() { ec.Tracer.EndFieldExecution(ctx) }() + rctx := &graphql.ResolverContext{ + Object: "ModelMethods", + Args: nil, + Field: field, + } + ctx = graphql.WithResolverContext(ctx, rctx) + ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) + resTmp := ec.FieldMiddleware(ctx, obj, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.NoContext(), nil + }) + if resTmp == nil { + if !ec.HasError(rctx) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(bool) + rctx.Result = res + ctx = ec.Tracer.StartFieldChildExecution(ctx) + return graphql.MarshalBoolean(res) +} + +// nolint: vetshadow +func (ec *executionContext) _ModelMethods_withContext(ctx context.Context, field graphql.CollectedField, obj *ModelMethods) graphql.Marshaler { + ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func() { ec.Tracer.EndFieldExecution(ctx) }() + rctx := &graphql.ResolverContext{ + Object: "ModelMethods", + Args: nil, + Field: field, + } + ctx = graphql.WithResolverContext(ctx, rctx) + ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) + resTmp := ec.FieldMiddleware(ctx, obj, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.WithContext(ctx), nil + }) + if resTmp == nil { + if !ec.HasError(rctx) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(bool) + rctx.Result = res + ctx = ec.Tracer.StartFieldChildExecution(ctx) + return graphql.MarshalBoolean(res) +} + var outerObjectImplementors = []string{"OuterObject"} // nolint: gocyclo, errcheck, gas, goconst @@ -1566,6 +1736,12 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr out.Values[i] = ec._Query_errorBubble(ctx, field) wg.Done() }(i, field) + case "modelMethods": + wg.Add(1) + go func(i int, field graphql.CollectedField) { + out.Values[i] = ec._Query_modelMethods(ctx, field) + wg.Done() + }(i, field) case "valid": wg.Add(1) go func(i int, field graphql.CollectedField) { @@ -1989,6 +2165,35 @@ func (ec *executionContext) _Query_errorBubble(ctx context.Context, field graphq return ec._Error(ctx, field.Selections, res) } +// nolint: vetshadow +func (ec *executionContext) _Query_modelMethods(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { + ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func() { ec.Tracer.EndFieldExecution(ctx) }() + rctx := &graphql.ResolverContext{ + Object: "Query", + Args: nil, + Field: field, + } + ctx = graphql.WithResolverContext(ctx, rctx) + ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) + resTmp := ec.FieldMiddleware(ctx, nil, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().ModelMethods(rctx) + }) + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*ModelMethods) + rctx.Result = res + ctx = ec.Tracer.StartFieldChildExecution(ctx) + + if res == nil { + return graphql.Null + } + + return ec._ModelMethods(ctx, field.Selections, res) +} + // nolint: vetshadow func (ec *executionContext) _Query_valid(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { ctx = ec.Tracer.StartFieldExecution(ctx, field) @@ -4204,6 +4409,7 @@ var parsedSchema = gqlparser.MustLoadSchema( keywords(input: Keywords): Boolean! shapes: [Shape] errorBubble: Error + modelMethods: ModelMethods valid: String! user(id: Int!): User! nullableArg(arg: Int = 123): String @@ -4226,6 +4432,12 @@ type Error { nilOnRequiredField: String! } +type ModelMethods { + resolverField: Boolean! + noContext: Boolean! + withContext: Boolean! +} + type InvalidIdentifier { id: Int! } diff --git a/codegen/testserver/generated_test.go b/codegen/testserver/generated_test.go index cbfd54715b1..2861f5358cd 100644 --- a/codegen/testserver/generated_test.go +++ b/codegen/testserver/generated_test.go @@ -519,6 +519,35 @@ func TestTracer(t *testing.T) { require.NoError(t, err) require.True(t, called) }) + + t.Run("model methods", func(t *testing.T) { + srv := httptest.NewServer( + handler.GraphQL( + NewExecutableSchema(Config{Resolvers: &testResolver{}}), + )) + defer srv.Close() + c := client.New(srv.URL) + t.Run("without context", func(t *testing.T) { + var resp struct { + ModelMethods struct { + NoContext bool + } + } + err := c.Post(`query { modelMethods{ noContext } }`, &resp) + require.NoError(t, err) + require.True(t, resp.ModelMethods.NoContext) + }) + t.Run("with context", func(t *testing.T) { + var resp struct { + ModelMethods struct { + WithContext bool + } + } + err := c.Post(`query { modelMethods{ withContext } }`, &resp) + require.NoError(t, err) + require.True(t, resp.ModelMethods.WithContext) + }) + }) } func TestResponseExtension(t *testing.T) { @@ -556,6 +585,15 @@ func (r *testResolver) User() UserResolver { func (r *testResolver) Query() QueryResolver { return &testQueryResolver{} } +func (r *testResolver) ModelMethods() ModelMethodsResolver { + return &testModelMethodsResolver{} +} + +type testModelMethodsResolver struct{} + +func (r *testModelMethodsResolver) ResolverField(ctx context.Context, obj *ModelMethods) (bool, error) { + return true, nil +} type testQueryResolver struct{ queryResolver } @@ -576,6 +614,10 @@ func (r *testQueryResolver) NullableArg(ctx context.Context, arg *int) (*string, return &s, nil } +func (r *testQueryResolver) ModelMethods(ctx context.Context) (*ModelMethods, error) { + return &ModelMethods{}, nil +} + func (r *testResolver) Subscription() SubscriptionResolver { return &testSubscriptionResolver{r} } diff --git a/codegen/testserver/gqlgen.yml b/codegen/testserver/gqlgen.yml index 32c93d2a868..7a6e8e5674e 100644 --- a/codegen/testserver/gqlgen.yml +++ b/codegen/testserver/gqlgen.yml @@ -11,6 +11,8 @@ resolver: models: It: model: "github.com/99designs/gqlgen/codegen/testserver/introspection.It" + ModelMethods: + model: "github.com/99designs/gqlgen/codegen/testserver.ModelMethods" InvalidIdentifier: model: "github.com/99designs/gqlgen/codegen/testserver/invalid-packagename.InvalidIdentifier" Changes: diff --git a/codegen/testserver/models.go b/codegen/testserver/models.go index f8cb15fd25c..8647046a607 100644 --- a/codegen/testserver/models.go +++ b/codegen/testserver/models.go @@ -1,11 +1,25 @@ package testserver -import "fmt" +import ( + context "context" + "fmt" +) type ForcedResolver struct { Field Circle } +type ModelMethods struct { +} + +func (m ModelMethods) NoContext() bool { + return true +} + +func (m ModelMethods) WithContext(_ context.Context) bool { + return true +} + type Error struct { ID string } diff --git a/codegen/testserver/resolver.go b/codegen/testserver/resolver.go index 4406ed374a5..da0b97591ad 100644 --- a/codegen/testserver/resolver.go +++ b/codegen/testserver/resolver.go @@ -14,6 +14,9 @@ type Resolver struct{} func (r *Resolver) ForcedResolver() ForcedResolverResolver { return &forcedResolverResolver{r} } +func (r *Resolver) ModelMethods() ModelMethodsResolver { + return &modelMethodsResolver{r} +} func (r *Resolver) Query() QueryResolver { return &queryResolver{r} } @@ -30,6 +33,12 @@ func (r *forcedResolverResolver) Field(ctx context.Context, obj *ForcedResolver) panic("not implemented") } +type modelMethodsResolver struct{ *Resolver } + +func (r *modelMethodsResolver) ResolverField(ctx context.Context, obj *ModelMethods) (bool, error) { + panic("not implemented") +} + type queryResolver struct{ *Resolver } func (r *queryResolver) InvalidIdentifier(ctx context.Context) (*invalid_packagename.InvalidIdentifier, error) { @@ -59,6 +68,9 @@ func (r *queryResolver) Shapes(ctx context.Context) ([]*Shape, error) { func (r *queryResolver) ErrorBubble(ctx context.Context) (*Error, error) { panic("not implemented") } +func (r *queryResolver) ModelMethods(ctx context.Context) (*ModelMethods, error) { + panic("not implemented") +} func (r *queryResolver) Valid(ctx context.Context) (string, error) { panic("not implemented") } diff --git a/codegen/testserver/schema.graphql b/codegen/testserver/schema.graphql index 02a69212482..276e7130be2 100644 --- a/codegen/testserver/schema.graphql +++ b/codegen/testserver/schema.graphql @@ -8,6 +8,7 @@ type Query { keywords(input: Keywords): Boolean! shapes: [Shape] errorBubble: Error + modelMethods: ModelMethods valid: String! user(id: Int!): User! nullableArg(arg: Int = 123): String @@ -30,6 +31,12 @@ type Error { nilOnRequiredField: String! } +type ModelMethods { + resolverField: Boolean! + noContext: Boolean! + withContext: Boolean! +} + type InvalidIdentifier { id: Int! } diff --git a/codegen/util.go b/codegen/util.go index 2ecd7c9e750..c6672bfc194 100644 --- a/codegen/util.go +++ b/codegen/util.go @@ -255,7 +255,19 @@ func bindMethod(imports *Imports, t types.Type, field *Field) error { } else if sig.Results().Len() != 2 { return fmt.Errorf("method has wrong number of args") } - newArgs, err := matchArgs(field, sig.Params()) + params := sig.Params() + // If the first argument is the context, remove it from the comparison and set + // the MethodHasContext flag so that the context will be passed to this model's method + if params.Len() > 0 && params.At(0).Type().String() == "context.Context" { + field.MethodHasContext = true + vars := make([]*types.Var, params.Len()-1) + for i := 1; i < params.Len(); i++ { + vars[i-1] = params.At(i) + } + params = types.NewTuple(vars...) + } + + newArgs, err := matchArgs(field, params) if err != nil { return err } diff --git a/docs/content/reference/resolvers.md b/docs/content/reference/resolvers.md index 0887f234805..ffac1046be6 100644 --- a/docs/content/reference/resolvers.md +++ b/docs/content/reference/resolvers.md @@ -87,7 +87,10 @@ models: ``` Here, we see that there is a method on car with the name ```Owner```, thus the ```Owner``` function will be called if -a graphQL request includes that field to be resolved +a graphQL request includes that field to be resolved. + +Model methods can optionally take a context as their first argument. If a +context is required, the model method will also be run in parallel. ## Bind when the field names do not match