Skip to content

Commit

Permalink
Add support for custom enums
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Apr 1, 2018
1 parent 74ac827 commit 61a34a7
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 22 deletions.
1 change: 0 additions & 1 deletion codegen/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (*

objects := buildObjects(namedTypes, schema, prog, imports)
inputs := buildInputs(namedTypes, schema, prog, imports)
buildEnums(namedTypes, schema)

b := &Build{
PackageName: filepath.Base(destDir),
Expand Down
27 changes: 14 additions & 13 deletions codegen/enum_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,23 @@ func buildEnums(types NamedTypes, s *schema.Schema) []Enum {
var enums []Enum

for _, typ := range s.Types {
if strings.HasPrefix(typ.TypeName(), "__") {
namedType := types[typ.TypeName()]
e, isEnum := typ.(*schema.Enum)
if !isEnum || strings.HasPrefix(typ.TypeName(), "__") || namedType.IsUserDefined {
continue
}
if e, ok := typ.(*schema.Enum); ok {
var values []EnumValue
for _, v := range e.Values {
values = append(values, EnumValue{v.Name, v.Desc})
}

enum := Enum{
NamedType: types[e.TypeName()],
Values: values,
}
enum.GoType = ucFirst(enum.GQLType)
enums = append(enums, enum)

var values []EnumValue
for _, v := range e.Values {
values = append(values, EnumValue{v.Name, v.Desc})
}

enum := Enum{
NamedType: namedType,
Values: values,
}
enum.GoType = ucFirst(enum.GQLType)
enums = append(enums, enum)
}

sort.Slice(enums, func(i, j int) bool {
Expand Down
6 changes: 3 additions & 3 deletions codegen/models_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ func buildModels(types NamedTypes, s *schema.Schema, prog *loader.Program) []Mod
switch typ := typ.(type) {
case *schema.Object:
obj := buildObject(types, typ, s)
if obj.Root || obj.GoType != "" {
if obj.Root || obj.IsUserDefined {
continue
}
model = obj2Model(s, obj)
case *schema.InputObject:
obj := buildInput(types, typ)
if obj.GoType != "" {
if obj.IsUserDefined {
continue
}
model = obj2Model(s, obj)
case *schema.Interface, *schema.Union:
intf := buildInterface(types, typ, prog)
if intf.GoType != "" {
if intf.IsUserDefined {
continue
}
model = int2Model(intf)
Expand Down
7 changes: 4 additions & 3 deletions codegen/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ type NamedType struct {
}

type Ref struct {
GoType string // Name of the go type
Package string // the package the go type lives in
Import *Import // the resolved import with alias
GoType string // Name of the go type
Package string // the package the go type lives in
Import *Import // the resolved import with alias
IsUserDefined bool // does the type exist in the typemap
}

type Type struct {
Expand Down
1 change: 1 addition & 0 deletions codegen/type_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func buildNamedTypes(s *schema.Schema, userTypes map[string]string) NamedTypes {
t := namedTypeFromSchema(schemaType)

userType := userTypes[t.GQLType]
t.IsUserDefined = userType != ""
if userType == "" && t.IsScalar {
userType = "github.com/vektah/gqlgen/graphql.String"
}
Expand Down
14 changes: 14 additions & 0 deletions example/scalars/generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ func (ec *executionContext) _User(ctx context.Context, sel []query.Selection, ob
out.Values[i] = ec._User_customResolver(ctx, field, obj)
case "address":
out.Values[i] = ec._User_address(ctx, field, obj)
case "tier":
out.Values[i] = ec._User_tier(ctx, field, obj)
default:
panic("unknown field " + strconv.Quote(field.Name))
}
Expand Down Expand Up @@ -322,6 +324,11 @@ func (ec *executionContext) _User_address(ctx context.Context, field graphql.Col
return ec._Address(ctx, field.Selections, &res)
}

func (ec *executionContext) _User_tier(ctx context.Context, field graphql.CollectedField, obj *User) graphql.Marshaler {
res := obj.Tier
return res
}

var __DirectiveImplementors = []string{"__Directive"}

// nolint: gocyclo, errcheck, gas, goconst
Expand Down Expand Up @@ -881,6 +888,7 @@ type User {
primitiveResolver: String!
customResolver: Point!
address: Address
tier: Tier
}
type Address {
Expand All @@ -894,6 +902,12 @@ input SearchArgs {
isBanned: Boolean
}
enum Tier {
A
B
C
}
scalar Timestamp
scalar Point
`)
59 changes: 59 additions & 0 deletions example/scalars/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type User struct {
Created time.Time // direct binding to builtin types with external Marshal/Unmarshal methods
IsBanned Banned // aliased primitive
Address Address
Tier Tier
}

// Point is serialized as a simple array, eg [1, 2]
Expand Down Expand Up @@ -95,3 +96,61 @@ type SearchArgs struct {
CreatedAfter *time.Time
IsBanned Banned
}

// A custom enum that uses integers to represent the values in memory but serialize as string for graphql
type Tier uint

const (
TierA Tier = iota
TierB Tier = iota
TierC Tier = iota
)

func TierForStr(str string) (Tier, error) {
switch str {
case "A":
return TierA, nil
case "B":
return TierB, nil
case "C":
return TierC, nil
default:
return 0, fmt.Errorf("%s is not a valid Tier", str)
}
}

func (e Tier) IsValid() bool {
switch e {
case TierA, TierB, TierC:
return true
}
return false
}

func (e Tier) String() string {
switch e {
case TierA:
return "A"
case TierB:
return "B"
case TierC:
return "C"
default:
panic("invalid enum value")
}
}

func (e *Tier) UnmarshalGQL(v interface{}) error {
str, ok := v.(string)
if !ok {
return fmt.Errorf("enums must be strings")
}

var err error
*e, err = TierForStr(str)
return err
}

func (e Tier) MarshalGQL(w io.Writer) {
fmt.Fprint(w, strconv.Quote(e.String()))
}
3 changes: 3 additions & 0 deletions example/scalars/resolvers.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ func (r *Resolver) Query_user(ctx context.Context, id external.ObjectID) (*User,
Name: fmt.Sprintf("Test User %d", id),
Created: time.Now(),
Address: Address{ID: 1, Location: &Point{1, 2}},
Tier: TierC,
}, nil
}

Expand All @@ -39,12 +40,14 @@ func (r *Resolver) Query_search(ctx context.Context, input SearchArgs) ([]User,
Name: "Test User 1",
Created: created,
Address: Address{ID: 2, Location: &location},
Tier: TierA,
},
{
ID: 2,
Name: "Test User 2",
Created: created,
Address: Address{ID: 1, Location: &location},
Tier: TierC,
},
}, nil
}
Expand Down
4 changes: 3 additions & 1 deletion example/scalars/scalar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type RawUser struct {
Address struct{ Location string }
PrimitiveResolver string
CustomResolver string
Tier string
}

func TestScalars(t *testing.T) {
Expand All @@ -37,12 +38,13 @@ func TestScalars(t *testing.T) {
...UserData
}
}
fragment UserData on User { id name created address { location } }`, &resp)
fragment UserData on User { id name created tier address { location } }`, &resp)

require.Equal(t, "1,2", resp.User.Address.Location)
require.Equal(t, time.Now().Unix(), resp.User.Created)
require.Equal(t, "6,66", resp.Search[0].Address.Location)
require.Equal(t, int64(666), resp.Search[0].Created)
require.Equal(t, "A", resp.Search[0].Tier)
})

t.Run("default search location", func(t *testing.T) {
Expand Down
7 changes: 7 additions & 0 deletions example/scalars/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type User {
primitiveResolver: String!
customResolver: Point!
address: Address
tier: Tier
}

type Address {
Expand All @@ -24,5 +25,11 @@ input SearchArgs {
isBanned: Boolean
}

enum Tier {
A
B
C
}

scalar Timestamp
scalar Point
3 changes: 2 additions & 1 deletion example/scalars/types.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
"Timestamp": "github.com/vektah/gqlgen/example/scalars.Timestamp",
"SearchArgs": "github.com/vektah/gqlgen/example/scalars.SearchArgs",
"Point": "github.com/vektah/gqlgen/example/scalars.Point",
"ID": "github.com/vektah/gqlgen/example/scalars.ID"
"ID": "github.com/vektah/gqlgen/example/scalars.ID",
"Tier": "github.com/vektah/gqlgen/example/scalars.Tier"
}

0 comments on commit 61a34a7

Please sign in to comment.