diff --git a/codegen/config/binder.go b/codegen/config/binder.go index bedc23bc61..212a168652 100644 --- a/codegen/config/binder.go +++ b/codegen/config/binder.go @@ -183,15 +183,16 @@ func (b *Binder) PointerTo(ref *TypeReference) *TypeReference { // TypeReference is used by args and field types. The Definition can refer to both input and output types. type TypeReference struct { - Definition *ast.Definition - GQL *ast.Type - GO types.Type // Type of the field being bound. Could be a pointer or a value type of Target. - Target types.Type // The actual type that we know how to bind to. May require pointer juggling when traversing to fields. - CastType types.Type // Before calling marshalling functions cast from/to this base type - Marshaler *types.Func // When using external marshalling functions this will point to the Marshal function - Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function - IsMarshaler bool // Does the type implement graphql.Marshaler and graphql.Unmarshaler - IsContext bool // Is the Marshaler/Unmarshaller the context version; applies to either the method or interface variety. + Definition *ast.Definition + GQL *ast.Type + GO types.Type // Type of the field being bound. Could be a pointer or a value type of Target. + Target types.Type // The actual type that we know how to bind to. May require pointer juggling when traversing to fields. + CastType types.Type // Before calling marshalling functions cast from/to this base type + Marshaler *types.Func // When using external marshalling functions this will point to the Marshal function + Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function + IsMarshaler bool // Does the type implement graphql.Marshaler and graphql.Unmarshaler + IsContext bool // Is the Marshaler/Unmarshaller the context version; applies to either the method or interface variety. + PointersInUmarshalInput bool // Inverse values and pointers in return. } func (ref *TypeReference) Elem() *TypeReference { @@ -412,6 +413,8 @@ func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret ref.GO = bindTarget } + ref.PointersInUmarshalInput = b.cfg.ReturnPointersInUmarshalInput + return ref, nil } diff --git a/codegen/config/config.go b/codegen/config/config.go index c6de8625f0..b13ee1a120 100644 --- a/codegen/config/config.go +++ b/codegen/config/config.go @@ -28,6 +28,7 @@ type Config struct { OmitSliceElementPointers bool `yaml:"omit_slice_element_pointers,omitempty"` OmitGetters bool `yaml:"omit_getters,omitempty"` StructFieldsAlwaysPointers bool `yaml:"struct_fields_always_pointers,omitempty"` + ReturnPointersInUmarshalInput bool `yaml:"return_pointers_in_unmarshalinput,omitempty"` ResolversAlwaysReturnPointers bool `yaml:"resolvers_always_return_pointers,omitempty"` SkipValidation bool `yaml:"skip_validation,omitempty"` SkipModTidy bool `yaml:"skip_mod_tidy,omitempty"` @@ -50,6 +51,7 @@ func DefaultConfig() *Config { Directives: map[string]DirectiveConfig{}, Models: TypeMap{}, StructFieldsAlwaysPointers: true, + ReturnPointersInUmarshalInput: false, ResolversAlwaysReturnPointers: true, } } diff --git a/codegen/input.gotpl b/codegen/input.gotpl index 116fe9ce76..85480d1f28 100644 --- a/codegen/input.gotpl +++ b/codegen/input.gotpl @@ -1,6 +1,10 @@ {{- range $input := .Inputs }} {{- if not .HasUnmarshal }} - func (ec *executionContext) unmarshalInput{{ .Name }}(ctx context.Context, obj interface{}) ({{.Type | ref}}, error) { + {{- $it := "it" }} + {{- if .PointersInUmarshalInput }} + {{- $it = "&it" }} + {{- end }} + func (ec *executionContext) unmarshalInput{{ .Name }}(ctx context.Context, obj interface{}) ({{ if .PointersInUmarshalInput }}*{{ end }}{{.Type | ref}}, error) { var it {{.Type | ref}} asMap := map[string]interface{}{} for k, v := range obj.(map[string]interface{}) { @@ -31,12 +35,12 @@ {{ template "implDirectives" $field }} tmp, err := directive{{$field.ImplDirectives|len}}(ctx) if err != nil { - return it, graphql.ErrorOnPath(ctx, err) + return {{$it}}, graphql.ErrorOnPath(ctx, err) } if data, ok := tmp.({{ $field.TypeReference.GO | ref }}) ; ok { {{- if $field.IsResolver }} if err = ec.resolvers.{{ $field.ShortInvocation }}; err != nil { - return it, err + return {{$it}}, err } {{- else }} it.{{$field.GoFieldName}} = data @@ -49,21 +53,21 @@ {{- end }} } else { err := fmt.Errorf(`unexpected type %T from directive, should be {{ $field.TypeReference.GO }}`, tmp) - return it, graphql.ErrorOnPath(ctx, err) + return {{$it}}, graphql.ErrorOnPath(ctx, err) } {{- else }} {{- if $field.IsResolver }} data, err := ec.{{ $field.TypeReference.UnmarshalFunc }}(ctx, v) if err != nil { - return it, err + return {{$it}}, err } if err = ec.resolvers.{{ $field.ShortInvocation }}; err != nil { - return it, err + return {{$it}}, err } {{- else }} it.{{$field.GoFieldName}}, err = ec.{{ $field.TypeReference.UnmarshalFunc }}(ctx, v) if err != nil { - return it, err + return {{$it}}, err } {{- end }} {{- end }} @@ -71,7 +75,7 @@ } } - return it, nil + return {{$it}}, nil } {{- end }} {{ end }} diff --git a/codegen/object.go b/codegen/object.go index a9cb34061b..ed0042a61f 100644 --- a/codegen/object.go +++ b/codegen/object.go @@ -25,14 +25,15 @@ const ( type Object struct { *ast.Definition - Type types.Type - ResolverInterface types.Type - Root bool - Fields []*Field - Implements []*ast.Definition - DisableConcurrency bool - Stream bool - Directives []*Directive + Type types.Type + ResolverInterface types.Type + Root bool + Fields []*Field + Implements []*ast.Definition + DisableConcurrency bool + Stream bool + Directives []*Directive + PointersInUmarshalInput bool } func (b *builder) buildObject(typ *ast.Definition) (*Object, error) { @@ -42,11 +43,12 @@ func (b *builder) buildObject(typ *ast.Definition) (*Object, error) { } caser := cases.Title(language.English, cases.NoLower) obj := &Object{ - Definition: typ, - Root: b.Schema.Query == typ || b.Schema.Mutation == typ || b.Schema.Subscription == typ, - DisableConcurrency: typ == b.Schema.Mutation, - Stream: typ == b.Schema.Subscription, - Directives: dirs, + Definition: typ, + Root: b.Schema.Query == typ || b.Schema.Mutation == typ || b.Schema.Subscription == typ, + DisableConcurrency: typ == b.Schema.Mutation, + Stream: typ == b.Schema.Subscription, + Directives: dirs, + PointersInUmarshalInput: b.Config.ReturnPointersInUmarshalInput, ResolverInterface: types.NewNamed( types.NewTypeName(0, b.Config.Exec.Pkg(), caser.String(typ.Name)+"Resolver", nil), nil, diff --git a/codegen/type.gotpl b/codegen/type.gotpl index d5c3919588..eaa1718436 100644 --- a/codegen/type.gotpl +++ b/codegen/type.gotpl @@ -75,9 +75,11 @@ return res, graphql.ErrorOnPath(ctx, err) {{- else }} res, err := ec.unmarshalInput{{ $type.GQL.Name }}(ctx, v) - {{- if $type.IsNilable }} + {{- if and $type.IsNilable (not $type.PointersInUmarshalInput) }} return &res, graphql.ErrorOnPath(ctx, err) - {{- else}} + {{- else if and (not $type.IsNilable) $type.PointersInUmarshalInput }} + return *res, graphql.ErrorOnPath(ctx, err) + {{- else }} return res, graphql.ErrorOnPath(ctx, err) {{- end }} {{- end }} diff --git a/docs/content/config.md b/docs/content/config.md index a274947d56..d9e2d8f07c 100644 --- a/docs/content/config.md +++ b/docs/content/config.md @@ -51,6 +51,9 @@ resolver: # Optional: turn off to make resolvers return values instead of pointers for structs # resolvers_always_return_pointers: true +# Optional: turn on to return pointers instead of values in unmarshalInput +# return_pointers_in_unmarshalinput: false + # Optional: set to speed up generation time by not performing a final validation pass. # skip_validation: true