diff --git a/codegen/build.go b/codegen/build.go index 6cbf077d39e..27383982c69 100644 --- a/codegen/build.go +++ b/codegen/build.go @@ -27,6 +27,7 @@ type ModelBuild struct { PackageName string Imports Imports Models []Model + Enums []Enum } // Create a list of models that need to be generated @@ -45,6 +46,7 @@ func Models(schema *schema.Schema, userTypes map[string]string, destDir string) return &ModelBuild{ PackageName: filepath.Base(destDir), Models: models, + Enums: buildEnums(namedTypes, schema), Imports: buildImports(namedTypes, destDir), } } @@ -63,6 +65,7 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (* objects := buildObjects(namedTypes, schema, prog, imports) inputs := buildInputs(namedTypes, schema, prog, imports) + buildEnums(namedTypes, schema) b := &Build{ PackageName: filepath.Base(destDir), diff --git a/codegen/enum.go b/codegen/enum.go new file mode 100644 index 00000000000..e62fd2b175a --- /dev/null +++ b/codegen/enum.go @@ -0,0 +1,12 @@ +package codegen + +type Enum struct { + *NamedType + + Values []EnumValue +} + +type EnumValue struct { + Name string + Description string +} diff --git a/codegen/enum_build.go b/codegen/enum_build.go new file mode 100644 index 00000000000..75976da78ae --- /dev/null +++ b/codegen/enum_build.go @@ -0,0 +1,37 @@ +package codegen + +import ( + "sort" + "strings" + + "github.com/vektah/gqlgen/neelance/schema" +) + +func buildEnums(types NamedTypes, s *schema.Schema) []Enum { + var enums []Enum + + for _, typ := range s.Types { + if strings.HasPrefix(typ.TypeName(), "__") { + continue + } + if e, ok := typ.(*schema.Enum); ok { + var values []EnumValue + for _, v := range e.Values { + values = append(values, EnumValue{v.Name, v.Desc}) + } + + enum := Enum{ + NamedType: types[e.TypeName()], + Values: values, + } + enum.GoType = ucFirst(enum.GQLType) + enums = append(enums, enum) + } + } + + sort.Slice(enums, func(i, j int) bool { + return strings.Compare(enums[i].GQLType, enums[j].GQLType) == -1 + }) + + return enums +} diff --git a/codegen/templates/data.go b/codegen/templates/data.go index 339ee0c5c1a..8afd4ca8ad8 100644 --- a/codegen/templates/data.go +++ b/codegen/templates/data.go @@ -6,6 +6,6 @@ var data = map[string]string{ "generated.gotpl": "// This file was generated by github.com/vektah/gqlgen, DO NOT EDIT\n\npackage {{ .PackageName }}\n\nimport (\n{{- range $import := .Imports }}\n\t{{- $import.Write }}\n{{ end }}\n)\n\nfunc MakeExecutableSchema(resolvers Resolvers) graphql.ExecutableSchema {\n\treturn &executableSchema{resolvers: resolvers}\n}\n\ntype Resolvers interface {\n{{- range $object := .Objects -}}\n\t{{ range $field := $object.Fields -}}\n\t\t{{ $field.ResolverDeclaration }}\n\t{{ end }}\n{{- end }}\n}\n\ntype executableSchema struct {\n\tresolvers Resolvers\n}\n\nfunc (e *executableSchema) Schema() *schema.Schema {\n\treturn parsedSchema\n}\n\nfunc (e *executableSchema) Query(ctx context.Context, op *query.Operation) *graphql.Response {\n\t{{- if .QueryRoot }}\n\t\tec := executionContext{graphql.GetRequestContext(ctx), e.resolvers}\n\n\t\tdata := ec._{{.QueryRoot.GQLType}}(ctx, op.Selections)\n\t\tvar buf bytes.Buffer\n\t\tdata.MarshalGQL(&buf)\n\n\t\treturn &graphql.Response{\n\t\t\tData: buf.Bytes(),\n\t\t\tErrors: ec.Errors,\n\t\t}\n\t{{- else }}\n\t\treturn &graphql.Response{Errors: []*errors.QueryError{ {Message: \"queries are not supported\"} }}\n\t{{- end }}\n}\n\nfunc (e *executableSchema) Mutation(ctx context.Context, op *query.Operation) *graphql.Response {\n\t{{- if .MutationRoot }}\n\t\tec := executionContext{graphql.GetRequestContext(ctx), e.resolvers}\n\n\t\tdata := ec._{{.MutationRoot.GQLType}}(ctx, op.Selections)\n\t\tvar buf bytes.Buffer\n\t\tdata.MarshalGQL(&buf)\n\n\t\treturn &graphql.Response{\n\t\t\tData: buf.Bytes(),\n\t\t\tErrors: ec.Errors,\n\t\t}\n\t{{- else }}\n\t\treturn &graphql.Response{Errors: []*errors.QueryError{ {Message: \"mutations are not supported\"} }}\n\t{{- end }}\n}\n\nfunc (e *executableSchema) Subscription(ctx context.Context, op *query.Operation) func() *graphql.Response {\n\t{{- if .SubscriptionRoot }}\n\t\tec := executionContext{graphql.GetRequestContext(ctx), e.resolvers}\n\n\t\tnext := ec._{{.SubscriptionRoot.GQLType}}(ctx, op.Selections)\n\t\tif ec.Errors != nil {\n\t\t\treturn graphql.OneShot(&graphql.Response{Data: []byte(\"null\"), Errors: ec.Errors})\n\t\t}\n\n\t\tvar buf bytes.Buffer\n\t\treturn func() *graphql.Response {\n\t\t\tbuf.Reset()\n\t\t\tdata := next()\n\t\t\tif data == nil {\n\t\t\t\treturn nil\n\t\t\t}\n\t\t\tdata.MarshalGQL(&buf)\n\n\t\t\terrs := ec.Errors\n\t\t\tec.Errors = nil\n\t\t\treturn &graphql.Response{\n\t\t\t\tData: buf.Bytes(),\n\t\t\t\tErrors: errs,\n\t\t\t}\n\t\t}\n\t{{- else }}\n\t\treturn graphql.OneShot(&graphql.Response{Errors: []*errors.QueryError{ {Message: \"subscriptions are not supported\"} }})\n\t{{- end }}\n}\n\ntype executionContext struct {\n\t*graphql.RequestContext\n\n\tresolvers Resolvers\n}\n\n{{- range $object := .Objects }}\n\t{{ template \"object.gotpl\" $object }}\n\n\t{{- range $field := $object.Fields }}\n\t\t{{ template \"field.gotpl\" $field }}\n\t{{ end }}\n{{- end}}\n\n{{- range $interface := .Interfaces }}\n\t{{ template \"interface.gotpl\" $interface }}\n{{- end }}\n\n{{- range $input := .Inputs }}\n\t{{ template \"input.gotpl\" $input }}\n{{- end }}\n\nvar parsedSchema = schema.MustParse({{.SchemaRaw|quote}})\n\nfunc (ec *executionContext) introspectSchema() *introspection.Schema {\n\treturn introspection.WrapSchema(parsedSchema)\n}\n\nfunc (ec *executionContext) introspectType(name string) *introspection.Type {\n\tt := parsedSchema.Resolve(name)\n\tif t == nil {\n\t\treturn nil\n\t}\n\treturn introspection.WrapType(t)\n}\n", "input.gotpl": "\t{{- if .IsMarshaled }}\n\tfunc Unmarshal{{ .GQLType }}(v interface{}) ({{.FullName}}, error) {\n\t\tvar it {{.FullName}}\n\n\t\tfor k, v := range v.(map[string]interface{}) {\n\t\t\tswitch k {\n\t\t\t{{- range $field := .Fields }}\n\t\t\tcase {{$field.GQLName|quote}}:\n\t\t\t\tvar err error\n\t\t\t\t{{ $field.Unmarshal (print \"it.\" $field.GoVarName) \"v\" }}\n\t\t\t\tif err != nil {\n\t\t\t\t\treturn it, err\n\t\t\t\t}\n\t\t\t{{- end }}\n\t\t\t}\n\t\t}\n\n\t\treturn it, nil\n\t}\n\t{{- end }}\n", "interface.gotpl": "{{- $interface := . }}\n\nfunc (ec *executionContext) _{{$interface.GQLType}}(ctx context.Context, sel []query.Selection, obj *{{$interface.FullName}}) graphql.Marshaler {\n\tswitch obj := (*obj).(type) {\n\tcase nil:\n\t\treturn graphql.Null\n\t{{- range $implementor := $interface.Implementors }}\n\t\t{{- if $implementor.ValueReceiver }}\n\t\t\tcase {{$implementor.FullName}}:\n\t\t\t\treturn ec._{{$implementor.GQLType}}(ctx, sel, &obj)\n\t\t{{- end}}\n\t\tcase *{{$implementor.FullName}}:\n\t\t\treturn ec._{{$implementor.GQLType}}(ctx, sel, obj)\n\t{{- end }}\n\tdefault:\n\t\tpanic(fmt.Errorf(\"unexpected type %T\", obj))\n\t}\n}\n", - "models.gotpl": "// This file was generated by github.com/vektah/gqlgen, DO NOT EDIT\n\npackage {{ .PackageName }}\n\nimport (\n{{- range $import := .Imports }}\n\t{{- $import.Write }}\n{{ end }}\n)\n\n{{ range $model := .Models }}\n\t{{- if .IsInterface }}\n\t\ttype {{.GoType}} interface {}\n\t{{- else }}\n\t\ttype {{.GoType}} struct {\n\t\t\t{{- range $field := .Fields }}\n\t\t\t\t{{- if $field.GoVarName }}\n\t\t\t\t\t{{ $field.GoVarName }} {{$field.Signature}}\n\t\t\t\t{{- else }}\n\t\t\t\t\t{{ $field.GoFKName }} {{$field.GoFKType}}\n\t\t\t\t{{- end }}\n\t\t\t{{- end }}\n\t\t}\n\t{{- end }}\n{{- end}}\n", + "models.gotpl": "// This file was generated by github.com/vektah/gqlgen, DO NOT EDIT\n\npackage {{ .PackageName }}\n\nimport (\n{{- range $import := .Imports }}\n\t{{- $import.Write }}\n{{ end }}\n)\n\n{{ range $model := .Models }}\n\t{{- if .IsInterface }}\n\t\ttype {{.GoType}} interface {}\n\t{{- else }}\n\t\ttype {{.GoType}} struct {\n\t\t\t{{- range $field := .Fields }}\n\t\t\t\t{{- if $field.GoVarName }}\n\t\t\t\t\t{{ $field.GoVarName }} {{$field.Signature}}\n\t\t\t\t{{- else }}\n\t\t\t\t\t{{ $field.GoFKName }} {{$field.GoFKType}}\n\t\t\t\t{{- end }}\n\t\t\t{{- end }}\n\t\t}\n\t{{- end }}\n{{- end}}\n\n{{ range $enum := .Enums }}\n\ttype {{.GoType}} string\n\tconst (\n\t{{ range $value := .Values }}\n\t\t{{$enum.GoType}}{{ .Name|toCamel }} {{$enum.GoType}} = {{.Name|quote}} {{with .Description}} // {{.}} {{end}}\n\t{{- end }}\n\t)\n\n\tfunc (e {{.GoType}}) IsValid() bool {\n\t\tswitch e {\n\t\tcase {{ range $index, $element := .Values}}{{if $index}},{{end}}{{ $enum.GoType }}{{ $element.Name|toCamel }}{{end}}:\n\t\t\treturn true\n\t\t}\n\t\treturn false\n\t}\n\n\tfunc (e {{.GoType}}) String() string {\n\t\treturn string(e)\n\t}\n\n\tfunc (e *{{.GoType}}) UnmarshalGQL(v interface{}) error {\n\t\tstr, ok := v.(string)\n\t\tif !ok {\n\t\t\treturn fmt.Errorf(\"enums must be strings\")\n\t\t}\n\n\t\t*e = {{.GoType}}(str)\n\t\tif !e.IsValid() {\n\t\t\treturn fmt.Errorf(\"%s is not a valid {{.GQLType}}\", str)\n\t\t}\n\t\treturn nil\n\t}\n\n\tfunc (e {{.GoType}}) MarshalGQL(w io.Writer) {\n\t\tfmt.Fprint(w, strconv.Quote(e.String()))\n\t}\n\n{{- end }}\n", "object.gotpl": "{{ $object := . }}\n\nvar {{ $object.GQLType|lcFirst}}Implementors = {{$object.Implementors}}\n\n// nolint: gocyclo, errcheck, gas, goconst\n{{- if .Stream }}\nfunc (ec *executionContext) _{{$object.GQLType}}(ctx context.Context, sel []query.Selection) func() graphql.Marshaler {\n\tfields := graphql.CollectFields(ec.Doc, sel, {{$object.GQLType|lcFirst}}Implementors, ec.Variables)\n\n\tif len(fields) != 1 {\n\t\tec.Errorf(\"must subscribe to exactly one stream\")\n\t\treturn nil\n\t}\n\n\tswitch fields[0].Name {\n\t{{- range $field := $object.Fields }}\n\tcase \"{{$field.GQLName}}\":\n\t\treturn ec._{{$object.GQLType}}_{{$field.GQLName}}(ctx, fields[0])\n\t{{- end }}\n\tdefault:\n\t\tpanic(\"unknown field \" + strconv.Quote(fields[0].Name))\n\t}\n}\n{{- else }}\nfunc (ec *executionContext) _{{$object.GQLType}}(ctx context.Context, sel []query.Selection{{if not $object.Root}}, obj *{{$object.FullName}} {{end}}) graphql.Marshaler {\n\tfields := graphql.CollectFields(ec.Doc, sel, {{$object.GQLType|lcFirst}}Implementors, ec.Variables)\n\tout := graphql.NewOrderedMap(len(fields))\n\tfor i, field := range fields {\n\t\tout.Keys[i] = field.Alias\n\n\t\tswitch field.Name {\n\t\tcase \"__typename\":\n\t\t\tout.Values[i] = graphql.MarshalString({{$object.GQLType|quote}})\n\t\t{{- range $field := $object.Fields }}\n\t\tcase \"{{$field.GQLName}}\":\n\t\t\tout.Values[i] = ec._{{$object.GQLType}}_{{$field.GQLName}}(ctx, field{{if not $object.Root}}, obj{{end}})\n\t\t{{- end }}\n\t\tdefault:\n\t\t\tpanic(\"unknown field \" + strconv.Quote(field.Name))\n\t\t}\n\t}\n\n\treturn out\n}\n{{- end }}\n", } diff --git a/codegen/templates/models.gotpl b/codegen/templates/models.gotpl index 608c83791eb..8b09e0b084c 100644 --- a/codegen/templates/models.gotpl +++ b/codegen/templates/models.gotpl @@ -23,3 +23,42 @@ import ( } {{- end }} {{- end}} + +{{ range $enum := .Enums }} + type {{.GoType}} string + const ( + {{ range $value := .Values }} + {{$enum.GoType}}{{ .Name|toCamel }} {{$enum.GoType}} = {{.Name|quote}} {{with .Description}} // {{.}} {{end}} + {{- end }} + ) + + func (e {{.GoType}}) IsValid() bool { + switch e { + case {{ range $index, $element := .Values}}{{if $index}},{{end}}{{ $enum.GoType }}{{ $element.Name|toCamel }}{{end}}: + return true + } + return false + } + + func (e {{.GoType}}) String() string { + return string(e) + } + + func (e *{{.GoType}}) UnmarshalGQL(v interface{}) error { + str, ok := v.(string) + if !ok { + return fmt.Errorf("enums must be strings") + } + + *e = {{.GoType}}(str) + if !e.IsValid() { + return fmt.Errorf("%s is not a valid {{.GQLType}}", str) + } + return nil + } + + func (e {{.GoType}}) MarshalGQL(w io.Writer) { + fmt.Fprint(w, strconv.Quote(e.String())) + } + +{{- end }} diff --git a/codegen/templates/templates.go b/codegen/templates/templates.go index ed9c1c7577b..875c031f5f1 100644 --- a/codegen/templates/templates.go +++ b/codegen/templates/templates.go @@ -16,6 +16,7 @@ func Run(name string, tpldata interface{}) (*bytes.Buffer, error) { "ucFirst": ucFirst, "lcFirst": lcFirst, "quote": strconv.Quote, + "toCamel": toCamel, "dump": dump, }) @@ -54,6 +55,31 @@ func lcFirst(s string) string { return string(r) } +func isDelimiter(c rune) bool { + return c == '-' || c == '_' || unicode.IsSpace(c) +} + +func toCamel(s string) string { + buffer := make([]rune, 0, len(s)) + upper := true + + for _, c := range s { + if isDelimiter(c) { + upper = true + continue + } + + if upper { + buffer = append(buffer, unicode.ToUpper(c)) + } else { + buffer = append(buffer, unicode.ToLower(c)) + } + upper = false + } + + return string(buffer) +} + func dump(val interface{}) string { switch val := val.(type) { case int: diff --git a/example/starwars/generated.go b/example/starwars/generated.go index 57337edd34f..cc4c9cfd11b 100644 --- a/example/starwars/generated.go +++ b/example/starwars/generated.go @@ -31,10 +31,10 @@ type Resolvers interface { Human_friendsConnection(ctx context.Context, obj *Human, first *int, after *string) (FriendsConnection, error) Human_starships(ctx context.Context, obj *Human) ([]Starship, error) - Mutation_createReview(ctx context.Context, episode string, review Review) (*Review, error) + Mutation_createReview(ctx context.Context, episode Episode, review Review) (*Review, error) - Query_hero(ctx context.Context, episode string) (Character, error) - Query_reviews(ctx context.Context, episode string, since *time.Time) ([]Review, error) + Query_hero(ctx context.Context, episode Episode) (Character, error) + Query_reviews(ctx context.Context, episode Episode, since *time.Time) ([]Review, error) Query_search(ctx context.Context, text string) ([]SearchResult, error) Query_character(ctx context.Context, id string) (Character, error) Query_droid(ctx context.Context, id string) (*Droid, error) @@ -202,7 +202,7 @@ func (ec *executionContext) _Droid_appearsIn(ctx context.Context, field graphql. res := obj.AppearsIn arr1 := graphql.Array{} for idx1 := range res { - arr1 = append(arr1, func() graphql.Marshaler { return graphql.MarshalString(res[idx1]) }()) + arr1 = append(arr1, func() graphql.Marshaler { return res[idx1] }()) } return arr1 } @@ -377,10 +377,10 @@ func (ec *executionContext) _Human_name(ctx context.Context, field graphql.Colle } func (ec *executionContext) _Human_height(ctx context.Context, field graphql.CollectedField, obj *Human) graphql.Marshaler { - var arg0 string + var arg0 LengthUnit if tmp, ok := field.Args["unit"]; ok { var err error - arg0, err = graphql.UnmarshalString(tmp) + err = (&arg0).UnmarshalGQL(tmp) if err != nil { ec.Error(err) return graphql.Null @@ -388,7 +388,7 @@ func (ec *executionContext) _Human_height(ctx context.Context, field graphql.Col } else { var tmp interface{} = "METER" var err error - arg0, err = graphql.UnmarshalString(tmp) + err = (&arg0).UnmarshalGQL(tmp) if err != nil { ec.Error(err) return graphql.Null @@ -478,7 +478,7 @@ func (ec *executionContext) _Human_appearsIn(ctx context.Context, field graphql. res := obj.AppearsIn arr1 := graphql.Array{} for idx1 := range res { - arr1 = append(arr1, func() graphql.Marshaler { return graphql.MarshalString(res[idx1]) }()) + arr1 = append(arr1, func() graphql.Marshaler { return res[idx1] }()) } return arr1 } @@ -529,10 +529,10 @@ func (ec *executionContext) _Mutation(ctx context.Context, sel []query.Selection } func (ec *executionContext) _Mutation_createReview(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { - var arg0 string + var arg0 Episode if tmp, ok := field.Args["episode"]; ok { var err error - arg0, err = graphql.UnmarshalString(tmp) + err = (&arg0).UnmarshalGQL(tmp) if err != nil { ec.Error(err) return graphql.Null @@ -639,10 +639,10 @@ func (ec *executionContext) _Query(ctx context.Context, sel []query.Selection) g } func (ec *executionContext) _Query_hero(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { - var arg0 string + var arg0 Episode if tmp, ok := field.Args["episode"]; ok { var err error - arg0, err = graphql.UnmarshalString(tmp) + err = (&arg0).UnmarshalGQL(tmp) if err != nil { ec.Error(err) return graphql.Null @@ -650,7 +650,7 @@ func (ec *executionContext) _Query_hero(ctx context.Context, field graphql.Colle } else { var tmp interface{} = "NEWHOPE" var err error - arg0, err = graphql.UnmarshalString(tmp) + err = (&arg0).UnmarshalGQL(tmp) if err != nil { ec.Error(err) return graphql.Null @@ -676,10 +676,10 @@ func (ec *executionContext) _Query_hero(ctx context.Context, field graphql.Colle } func (ec *executionContext) _Query_reviews(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { - var arg0 string + var arg0 Episode if tmp, ok := field.Args["episode"]; ok { var err error - arg0, err = graphql.UnmarshalString(tmp) + err = (&arg0).UnmarshalGQL(tmp) if err != nil { ec.Error(err) return graphql.Null @@ -982,10 +982,10 @@ func (ec *executionContext) _Starship_name(ctx context.Context, field graphql.Co } func (ec *executionContext) _Starship_length(ctx context.Context, field graphql.CollectedField, obj *Starship) graphql.Marshaler { - var arg0 string + var arg0 LengthUnit if tmp, ok := field.Args["unit"]; ok { var err error - arg0, err = graphql.UnmarshalString(tmp) + err = (&arg0).UnmarshalGQL(tmp) if err != nil { ec.Error(err) return graphql.Null @@ -993,7 +993,7 @@ func (ec *executionContext) _Starship_length(ctx context.Context, field graphql. } else { var tmp interface{} = "METER" var err error - arg0, err = graphql.UnmarshalString(tmp) + err = (&arg0).UnmarshalGQL(tmp) if err != nil { ec.Error(err) return graphql.Null diff --git a/example/starwars/model.go b/example/starwars/model.go index 62de3c8115b..c06bcbbbfee 100644 --- a/example/starwars/model.go +++ b/example/starwars/model.go @@ -13,7 +13,7 @@ type CharacterFields struct { ID string Name string FriendIds []string - AppearsIn []string + AppearsIn []Episode } type Human struct { @@ -23,7 +23,7 @@ type Human struct { Mass float64 } -func (h *Human) Height(unit string) float64 { +func (h *Human) Height(unit LengthUnit) float64 { switch unit { case "METER", "": return h.heightMeters @@ -41,7 +41,7 @@ type Starship struct { lengthMeters float64 } -func (s *Starship) Length(unit string) float64 { +func (s *Starship) Length(unit LengthUnit) float64 { switch unit { case "METER", "": return s.lengthMeters diff --git a/example/starwars/models_gen.go b/example/starwars/models_gen.go index b8137a4dd63..e9cbc8ab97a 100644 --- a/example/starwars/models_gen.go +++ b/example/starwars/models_gen.go @@ -2,6 +2,12 @@ package starwars +import ( + fmt "fmt" + io "io" + strconv "strconv" +) + type Character interface{} type PageInfo struct { StartCursor string @@ -9,3 +15,76 @@ type PageInfo struct { HasNextPage bool } type SearchResult interface{} + +type Episode string + +const ( + EpisodeNewhope Episode = "NEWHOPE" // Star Wars Episode IV: A New Hope, released in 1977. + EpisodeEmpire Episode = "EMPIRE" // Star Wars Episode V: The Empire Strikes Back, released in 1980. + EpisodeJedi Episode = "JEDI" // Star Wars Episode VI: Return of the Jedi, released in 1983. +) + +func (e Episode) IsValid() bool { + switch e { + case EpisodeNewhope, EpisodeEmpire, EpisodeJedi: + return true + } + return false +} + +func (e Episode) String() string { + return string(e) +} + +func (e *Episode) UnmarshalGQL(v interface{}) error { + str, ok := v.(string) + if !ok { + return fmt.Errorf("enums must be strings") + } + + *e = Episode(str) + if !e.IsValid() { + return fmt.Errorf("%s is not a valid Episode", str) + } + return nil +} + +func (e Episode) MarshalGQL(w io.Writer) { + fmt.Fprint(w, strconv.Quote(e.String())) +} + +type LengthUnit string + +const ( + LengthUnitMeter LengthUnit = "METER" // The standard unit around the world + LengthUnitFoot LengthUnit = "FOOT" // Primarily used in the United States +) + +func (e LengthUnit) IsValid() bool { + switch e { + case LengthUnitMeter, LengthUnitFoot: + return true + } + return false +} + +func (e LengthUnit) String() string { + return string(e) +} + +func (e *LengthUnit) UnmarshalGQL(v interface{}) error { + str, ok := v.(string) + if !ok { + return fmt.Errorf("enums must be strings") + } + + *e = LengthUnit(str) + if !e.IsValid() { + return fmt.Errorf("%s is not a valid LengthUnit", str) + } + return nil +} + +func (e LengthUnit) MarshalGQL(w io.Writer) { + fmt.Fprint(w, strconv.Quote(e.String())) +} diff --git a/example/starwars/resolvers.go b/example/starwars/resolvers.go index 03b289bf504..fe18d7598bb 100644 --- a/example/starwars/resolvers.go +++ b/example/starwars/resolvers.go @@ -12,7 +12,7 @@ type Resolver struct { humans map[string]Human droid map[string]Droid starships map[string]Starship - reviews map[string][]Review + reviews map[Episode][]Review } func (r *Resolver) resolveCharacters(ctx context.Context, ids []string) ([]Character, error) { @@ -78,21 +78,21 @@ func (r *Resolver) FriendsConnection_friends(ctx context.Context, it *FriendsCon return r.resolveCharacters(ctx, it.ids) } -func (r *Resolver) Mutation_createReview(ctx context.Context, episode string, review Review) (*Review, error) { +func (r *Resolver) Mutation_createReview(ctx context.Context, episode Episode, review Review) (*Review, error) { review.Time = time.Now() time.Sleep(1 * time.Second) r.reviews[episode] = append(r.reviews[episode], review) return &review, nil } -func (r *Resolver) Query_hero(ctx context.Context, episode string) (Character, error) { - if episode == "EMPIRE" { +func (r *Resolver) Query_hero(ctx context.Context, episode Episode) (Character, error) { + if episode == EpisodeEmpire { return r.humans["1000"], nil } return r.droid["2001"], nil } -func (r *Resolver) Query_reviews(ctx context.Context, episode string, since *time.Time) ([]Review, error) { +func (r *Resolver) Query_reviews(ctx context.Context, episode Episode, since *time.Time) ([]Review, error) { if since == nil { return r.reviews[episode], nil } @@ -162,7 +162,7 @@ func NewResolver() *Resolver { ID: "1000", Name: "Luke Skywalker", FriendIds: []string{"1002", "1003", "2000", "2001"}, - AppearsIn: []string{"NEWHOPE", "EMPIRE", "JEDI"}, + AppearsIn: []Episode{EpisodeNewhope, EpisodeEmpire, EpisodeJedi}, }, heightMeters: 1.72, Mass: 77, @@ -173,7 +173,7 @@ func NewResolver() *Resolver { ID: "1001", Name: "Darth Vader", FriendIds: []string{"1004"}, - AppearsIn: []string{"NEWHOPE", "EMPIRE", "JEDI"}, + AppearsIn: []Episode{EpisodeNewhope, EpisodeEmpire, EpisodeJedi}, }, heightMeters: 2.02, Mass: 136, @@ -184,7 +184,7 @@ func NewResolver() *Resolver { ID: "1002", Name: "Han Solo", FriendIds: []string{"1000", "1003", "2001"}, - AppearsIn: []string{"NEWHOPE", "EMPIRE", "JEDI"}, + AppearsIn: []Episode{EpisodeNewhope, EpisodeEmpire, EpisodeJedi}, }, heightMeters: 1.8, Mass: 80, @@ -195,7 +195,7 @@ func NewResolver() *Resolver { ID: "1003", Name: "Leia Organa", FriendIds: []string{"1000", "1002", "2000", "2001"}, - AppearsIn: []string{"NEWHOPE", "EMPIRE", "JEDI"}, + AppearsIn: []Episode{EpisodeNewhope, EpisodeEmpire, EpisodeJedi}, }, heightMeters: 1.5, Mass: 49, @@ -205,7 +205,7 @@ func NewResolver() *Resolver { ID: "1004", Name: "Wilhuff Tarkin", FriendIds: []string{"1001"}, - AppearsIn: []string{"NEWHOPE"}, + AppearsIn: []Episode{EpisodeNewhope}, }, heightMeters: 1.8, Mass: 0, @@ -218,7 +218,7 @@ func NewResolver() *Resolver { ID: "2000", Name: "C-3PO", FriendIds: []string{"1000", "1002", "1003", "2001"}, - AppearsIn: []string{"NEWHOPE", "EMPIRE", "JEDI"}, + AppearsIn: []Episode{EpisodeNewhope, EpisodeEmpire, EpisodeJedi}, }, PrimaryFunction: "Protocol", }, @@ -227,7 +227,7 @@ func NewResolver() *Resolver { ID: "2001", Name: "R2-D2", FriendIds: []string{"1000", "1002", "1003"}, - AppearsIn: []string{"NEWHOPE", "EMPIRE", "JEDI"}, + AppearsIn: []Episode{EpisodeNewhope, EpisodeEmpire, EpisodeJedi}, }, PrimaryFunction: "Astromech", }, @@ -280,7 +280,7 @@ func NewResolver() *Resolver { }, } - r.reviews = map[string][]Review{} + r.reviews = map[Episode][]Review{} return &r } diff --git a/example/starwars/starwars_test.go b/example/starwars/starwars_test.go index 66d8dac4bcf..9f2912c7a33 100644 --- a/example/starwars/starwars_test.go +++ b/example/starwars/starwars_test.go @@ -206,6 +206,18 @@ func TestStarwars(t *testing.T) { require.Len(t, resp.Starship.History[0], 2) }) + t.Run("invalid enums in variables", func(t *testing.T) { + var resp struct{} + + err := c.Post(`mutation($episode: Episode!) { + createReview(episode: $episode, review:{stars:1, commentary:"Blah blah"}) { + time + } + }`, &resp, client.Var("episode", "INVALID")) + + require.EqualError(t, err, "errors: [graphql: INVALID is not a valid Episode]") + }) + t.Run("introspection", func(t *testing.T) { // Make sure we can run the graphiql introspection query without errors var resp interface{} diff --git a/main.go b/main.go index 695a37e60c5..28fbe4c63d6 100644 --- a/main.go +++ b/main.go @@ -72,6 +72,10 @@ func main() { for _, model := range modelsBuild.Models { types[model.GQLType] = pkgName + "." + model.GoType } + + for _, enum := range modelsBuild.Enums { + types[enum.GQLType] = pkgName + "." + enum.GoType + } } build, err := codegen.Bind(schema, types, dirName())