diff --git a/codegen/build.go b/codegen/build.go index 27383982c69..8f7dc6f11cc 100644 --- a/codegen/build.go +++ b/codegen/build.go @@ -65,7 +65,6 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (* objects := buildObjects(namedTypes, schema, prog, imports) inputs := buildInputs(namedTypes, schema, prog, imports) - buildEnums(namedTypes, schema) b := &Build{ PackageName: filepath.Base(destDir), diff --git a/codegen/enum_build.go b/codegen/enum_build.go index 75976da78ae..59a342f2ed4 100644 --- a/codegen/enum_build.go +++ b/codegen/enum_build.go @@ -11,22 +11,23 @@ func buildEnums(types NamedTypes, s *schema.Schema) []Enum { var enums []Enum for _, typ := range s.Types { - if strings.HasPrefix(typ.TypeName(), "__") { + namedType := types[typ.TypeName()] + e, isEnum := typ.(*schema.Enum) + if !isEnum || strings.HasPrefix(typ.TypeName(), "__") || namedType.IsUserDefined { continue } - if e, ok := typ.(*schema.Enum); ok { - var values []EnumValue - for _, v := range e.Values { - values = append(values, EnumValue{v.Name, v.Desc}) - } - - enum := Enum{ - NamedType: types[e.TypeName()], - Values: values, - } - enum.GoType = ucFirst(enum.GQLType) - enums = append(enums, enum) + + var values []EnumValue + for _, v := range e.Values { + values = append(values, EnumValue{v.Name, v.Desc}) + } + + enum := Enum{ + NamedType: namedType, + Values: values, } + enum.GoType = ucFirst(enum.GQLType) + enums = append(enums, enum) } sort.Slice(enums, func(i, j int) bool { diff --git a/codegen/models_build.go b/codegen/models_build.go index 278419f6207..d75deebde78 100644 --- a/codegen/models_build.go +++ b/codegen/models_build.go @@ -16,19 +16,19 @@ func buildModels(types NamedTypes, s *schema.Schema, prog *loader.Program) []Mod switch typ := typ.(type) { case *schema.Object: obj := buildObject(types, typ, s) - if obj.Root || obj.GoType != "" { + if obj.Root || obj.IsUserDefined { continue } model = obj2Model(s, obj) case *schema.InputObject: obj := buildInput(types, typ) - if obj.GoType != "" { + if obj.IsUserDefined { continue } model = obj2Model(s, obj) case *schema.Interface, *schema.Union: intf := buildInterface(types, typ, prog) - if intf.GoType != "" { + if intf.IsUserDefined { continue } model = int2Model(intf) diff --git a/codegen/type.go b/codegen/type.go index 7d945ccce56..9b3b799ae5c 100644 --- a/codegen/type.go +++ b/codegen/type.go @@ -17,9 +17,10 @@ type NamedType struct { } type Ref struct { - GoType string // Name of the go type - Package string // the package the go type lives in - Import *Import // the resolved import with alias + GoType string // Name of the go type + Package string // the package the go type lives in + Import *Import // the resolved import with alias + IsUserDefined bool // does the type exist in the typemap } type Type struct { diff --git a/codegen/type_build.go b/codegen/type_build.go index 8e5a14aa7d8..bbc9a64b6bf 100644 --- a/codegen/type_build.go +++ b/codegen/type_build.go @@ -17,6 +17,7 @@ func buildNamedTypes(s *schema.Schema, userTypes map[string]string) NamedTypes { t := namedTypeFromSchema(schemaType) userType := userTypes[t.GQLType] + t.IsUserDefined = userType != "" if userType == "" && t.IsScalar { userType = "github.com/vektah/gqlgen/graphql.String" } diff --git a/example/scalars/generated.go b/example/scalars/generated.go index f15596fb5f5..f521b7700fb 100644 --- a/example/scalars/generated.go +++ b/example/scalars/generated.go @@ -251,6 +251,8 @@ func (ec *executionContext) _User(ctx context.Context, sel []query.Selection, ob out.Values[i] = ec._User_customResolver(ctx, field, obj) case "address": out.Values[i] = ec._User_address(ctx, field, obj) + case "tier": + out.Values[i] = ec._User_tier(ctx, field, obj) default: panic("unknown field " + strconv.Quote(field.Name)) } @@ -322,6 +324,11 @@ func (ec *executionContext) _User_address(ctx context.Context, field graphql.Col return ec._Address(ctx, field.Selections, &res) } +func (ec *executionContext) _User_tier(ctx context.Context, field graphql.CollectedField, obj *User) graphql.Marshaler { + res := obj.Tier + return res +} + var __DirectiveImplementors = []string{"__Directive"} // nolint: gocyclo, errcheck, gas, goconst @@ -881,6 +888,7 @@ type User { primitiveResolver: String! customResolver: Point! address: Address + tier: Tier } type Address { @@ -894,6 +902,12 @@ input SearchArgs { isBanned: Boolean } +enum Tier { + A + B + C +} + scalar Timestamp scalar Point `) diff --git a/example/scalars/model.go b/example/scalars/model.go index ce86b28c26a..f2322fb64c0 100644 --- a/example/scalars/model.go +++ b/example/scalars/model.go @@ -21,6 +21,7 @@ type User struct { Created time.Time // direct binding to builtin types with external Marshal/Unmarshal methods IsBanned Banned // aliased primitive Address Address + Tier Tier } // Point is serialized as a simple array, eg [1, 2] @@ -95,3 +96,61 @@ type SearchArgs struct { CreatedAfter *time.Time IsBanned Banned } + +// A custom enum that uses integers to represent the values in memory but serialize as string for graphql +type Tier uint + +const ( + TierA Tier = iota + TierB Tier = iota + TierC Tier = iota +) + +func TierForStr(str string) (Tier, error) { + switch str { + case "A": + return TierA, nil + case "B": + return TierB, nil + case "C": + return TierC, nil + default: + return 0, fmt.Errorf("%s is not a valid Tier", str) + } +} + +func (e Tier) IsValid() bool { + switch e { + case TierA, TierB, TierC: + return true + } + return false +} + +func (e Tier) String() string { + switch e { + case TierA: + return "A" + case TierB: + return "B" + case TierC: + return "C" + default: + panic("invalid enum value") + } +} + +func (e *Tier) UnmarshalGQL(v interface{}) error { + str, ok := v.(string) + if !ok { + return fmt.Errorf("enums must be strings") + } + + var err error + *e, err = TierForStr(str) + return err +} + +func (e Tier) MarshalGQL(w io.Writer) { + fmt.Fprint(w, strconv.Quote(e.String())) +} diff --git a/example/scalars/resolvers.go b/example/scalars/resolvers.go index 07f80884800..15b8724f765 100644 --- a/example/scalars/resolvers.go +++ b/example/scalars/resolvers.go @@ -19,6 +19,7 @@ func (r *Resolver) Query_user(ctx context.Context, id external.ObjectID) (*User, Name: fmt.Sprintf("Test User %d", id), Created: time.Now(), Address: Address{ID: 1, Location: &Point{1, 2}}, + Tier: TierC, }, nil } @@ -39,12 +40,14 @@ func (r *Resolver) Query_search(ctx context.Context, input SearchArgs) ([]User, Name: "Test User 1", Created: created, Address: Address{ID: 2, Location: &location}, + Tier: TierA, }, { ID: 2, Name: "Test User 2", Created: created, Address: Address{ID: 1, Location: &location}, + Tier: TierC, }, }, nil } diff --git a/example/scalars/scalar_test.go b/example/scalars/scalar_test.go index b26ed9af97b..0899fde5f1a 100644 --- a/example/scalars/scalar_test.go +++ b/example/scalars/scalar_test.go @@ -18,6 +18,7 @@ type RawUser struct { Address struct{ Location string } PrimitiveResolver string CustomResolver string + Tier string } func TestScalars(t *testing.T) { @@ -37,12 +38,13 @@ func TestScalars(t *testing.T) { ...UserData } } - fragment UserData on User { id name created address { location } }`, &resp) + fragment UserData on User { id name created tier address { location } }`, &resp) require.Equal(t, "1,2", resp.User.Address.Location) require.Equal(t, time.Now().Unix(), resp.User.Created) require.Equal(t, "6,66", resp.Search[0].Address.Location) require.Equal(t, int64(666), resp.Search[0].Created) + require.Equal(t, "A", resp.Search[0].Tier) }) t.Run("default search location", func(t *testing.T) { diff --git a/example/scalars/schema.graphql b/example/scalars/schema.graphql index 29c6641792a..48463421e20 100644 --- a/example/scalars/schema.graphql +++ b/example/scalars/schema.graphql @@ -11,6 +11,7 @@ type User { primitiveResolver: String! customResolver: Point! address: Address + tier: Tier } type Address { @@ -24,5 +25,11 @@ input SearchArgs { isBanned: Boolean } +enum Tier { + A + B + C +} + scalar Timestamp scalar Point diff --git a/example/scalars/types.json b/example/scalars/types.json index 8cb05428335..26ac8440b0f 100644 --- a/example/scalars/types.json +++ b/example/scalars/types.json @@ -3,5 +3,6 @@ "Timestamp": "github.com/vektah/gqlgen/example/scalars.Timestamp", "SearchArgs": "github.com/vektah/gqlgen/example/scalars.SearchArgs", "Point": "github.com/vektah/gqlgen/example/scalars.Point", - "ID": "github.com/vektah/gqlgen/example/scalars.ID" + "ID": "github.com/vektah/gqlgen/example/scalars.ID", + "Tier": "github.com/vektah/gqlgen/example/scalars.Tier" }