From 63ee41996e5466e27ea2711923c546eef41c6183 Mon Sep 17 00:00:00 2001 From: Adam Scarr Date: Mon, 7 May 2018 16:21:04 +1000 Subject: [PATCH] Fix vendor normalization When refering to vendored types in fields a type assertion would fail. This PR makes sure that both paths are normalized to not include the vendor directory. --- codegen/type_build.go | 5 -- codegen/util.go | 12 +++- codegen/util_test.go | 14 ++++ test/generated.go | 116 +++++++++++++++++++++++++++++++++ test/resolvers_test.go | 8 +++ test/schema.graphql | 8 +++ test/types.json | 4 +- test/vendor/remote_api/user.go | 5 ++ test/viewer.go | 7 ++ 9 files changed, 172 insertions(+), 7 deletions(-) create mode 100644 codegen/util_test.go create mode 100644 test/vendor/remote_api/user.go create mode 100644 test/viewer.go diff --git a/codegen/type_build.go b/codegen/type_build.go index aea5a425eb3..6ecdbfdf453 100644 --- a/codegen/type_build.go +++ b/codegen/type_build.go @@ -76,11 +76,6 @@ func pkgAndType(name string) (string, string) { return normalizeVendor(strings.Join(parts[:len(parts)-1], ".")), parts[len(parts)-1] } -func normalizeVendor(pkg string) string { - parts := strings.Split(pkg, "/vendor/") - return parts[len(parts)-1] -} - func (n NamedTypes) getType(t common.Type) *Type { var modifiers []string usePtr := true diff --git a/codegen/util.go b/codegen/util.go index c69059a40d1..f727a3e90b1 100644 --- a/codegen/util.go +++ b/codegen/util.go @@ -2,6 +2,7 @@ package codegen import ( "go/types" + "regexp" "strings" "github.com/pkg/errors" @@ -176,7 +177,7 @@ func bindObject(t types.Type, object *Object, imports Imports) error { field.Type.Modifiers = modifiersFromGoType(structField.Type()) field.GoVarName = structField.Name() - switch field.Type.FullSignature() { + switch normalizeVendor(field.Type.FullSignature()) { case normalizeVendor(structField.Type().String()): // everything is fine @@ -215,3 +216,12 @@ func modifiersFromGoType(t types.Type) []string { } } } + +var modsRegex = regexp.MustCompile(`^(\*|\[\])*`) + +func normalizeVendor(pkg string) string { + modifiers := modsRegex.FindAllString(pkg, 1)[0] + pkg = strings.TrimPrefix(pkg, modifiers) + parts := strings.Split(pkg, "/vendor/") + return modifiers + parts[len(parts)-1] +} diff --git a/codegen/util_test.go b/codegen/util_test.go new file mode 100644 index 00000000000..cb41170ddb6 --- /dev/null +++ b/codegen/util_test.go @@ -0,0 +1,14 @@ +package codegen + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNormalizeVendor(t *testing.T) { + require.Equal(t, "bar/baz", normalizeVendor("foo/vendor/bar/baz")) + require.Equal(t, "[]bar/baz", normalizeVendor("[]foo/vendor/bar/baz")) + require.Equal(t, "*bar/baz", normalizeVendor("*foo/vendor/bar/baz")) + require.Equal(t, "*[]*bar/baz", normalizeVendor("*[]*foo/vendor/bar/baz")) +} diff --git a/test/generated.go b/test/generated.go index 06350707304..de21730f4cb 100644 --- a/test/generated.go +++ b/test/generated.go @@ -5,6 +5,7 @@ package test import ( "bytes" context "context" + remote_api "remote_api" strconv "strconv" graphql "github.com/vektah/gqlgen/graphql" @@ -23,6 +24,7 @@ type Resolvers interface { Element_error(ctx context.Context, obj *Element) (bool, error) Query_path(ctx context.Context) ([]Element, error) Query_date(ctx context.Context, filter models.DateFilter) (bool, error) + Query_viewer(ctx context.Context) (*Viewer, error) } type executableSchema struct { @@ -169,6 +171,8 @@ func (ec *executionContext) _Query(ctx context.Context, sel []query.Selection) g out.Values[i] = ec._Query_path(ctx, field) case "date": out.Values[i] = ec._Query_date(ctx, field) + case "viewer": + out.Values[i] = ec._Query_viewer(ctx, field) case "__schema": out.Values[i] = ec._Query___schema(ctx, field) case "__type": @@ -261,6 +265,39 @@ func (ec *executionContext) _Query_date(ctx context.Context, field graphql.Colle }) } +func (ec *executionContext) _Query_viewer(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { + ctx = graphql.WithResolverContext(ctx, &graphql.ResolverContext{ + Object: "Query", + Args: nil, + Field: field, + }) + return graphql.Defer(func() (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + userErr := ec.Recover(ctx, r) + ec.Error(ctx, userErr) + ret = graphql.Null + } + }() + + resTmp, err := ec.ResolverMiddleware(ctx, func(ctx context.Context) (interface{}, error) { + return ec.resolvers.Query_viewer(ctx) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*Viewer) + if res == nil { + return graphql.Null + } + return ec._Viewer(ctx, field.Selections, res) + }) +} + func (ec *executionContext) _Query___schema(ctx context.Context, field graphql.CollectedField) graphql.Marshaler { rctx := graphql.GetResolverContext(ctx) rctx.Object = "Query" @@ -300,6 +337,77 @@ func (ec *executionContext) _Query___type(ctx context.Context, field graphql.Col return ec.___Type(ctx, field.Selections, res) } +var userImplementors = []string{"User"} + +// nolint: gocyclo, errcheck, gas, goconst +func (ec *executionContext) _User(ctx context.Context, sel []query.Selection, obj *remote_api.User) graphql.Marshaler { + fields := graphql.CollectFields(ec.Doc, sel, userImplementors, ec.Variables) + + out := graphql.NewOrderedMap(len(fields)) + for i, field := range fields { + out.Keys[i] = field.Alias + + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("User") + case "name": + out.Values[i] = ec._User_name(ctx, field, obj) + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + + return out +} + +func (ec *executionContext) _User_name(ctx context.Context, field graphql.CollectedField, obj *remote_api.User) graphql.Marshaler { + rctx := graphql.GetResolverContext(ctx) + rctx.Object = "User" + rctx.Args = nil + rctx.Field = field + rctx.PushField(field.Alias) + defer rctx.Pop() + res := obj.Name + return graphql.MarshalString(res) +} + +var viewerImplementors = []string{"Viewer"} + +// nolint: gocyclo, errcheck, gas, goconst +func (ec *executionContext) _Viewer(ctx context.Context, sel []query.Selection, obj *Viewer) graphql.Marshaler { + fields := graphql.CollectFields(ec.Doc, sel, viewerImplementors, ec.Variables) + + out := graphql.NewOrderedMap(len(fields)) + for i, field := range fields { + out.Keys[i] = field.Alias + + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("Viewer") + case "user": + out.Values[i] = ec._Viewer_user(ctx, field, obj) + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + + return out +} + +func (ec *executionContext) _Viewer_user(ctx context.Context, field graphql.CollectedField, obj *Viewer) graphql.Marshaler { + rctx := graphql.GetResolverContext(ctx) + rctx.Object = "Viewer" + rctx.Args = nil + rctx.Field = field + rctx.PushField(field.Alias) + defer rctx.Pop() + res := obj.User + if res == nil { + return graphql.Null + } + return ec._User(ctx, field.Selections, res) +} + var __DirectiveImplementors = []string{"__Directive"} // nolint: gocyclo, errcheck, gas, goconst @@ -1105,8 +1213,16 @@ input DateFilter { op: DATE_FILTER_OP = EQ } +type User { + name: String +} +type Viewer { + user: User +} + type Query { path: [Element] date(filter: DateFilter!): Boolean! + viewer: Viewer } `) diff --git a/test/resolvers_test.go b/test/resolvers_test.go index 05e355a2380..e732ed48bc8 100644 --- a/test/resolvers_test.go +++ b/test/resolvers_test.go @@ -9,6 +9,8 @@ import ( "testing" "time" + "remote_api" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -85,6 +87,12 @@ type testResolvers struct { queryDate func(ctx context.Context, filter models.DateFilter) (bool, error) } +func (r *testResolvers) Query_viewer(ctx context.Context) (*Viewer, error) { + return &Viewer{ + User: &remote_api.User{"Bob"}, + }, nil +} + func (r *testResolvers) Query_date(ctx context.Context, filter models.DateFilter) (bool, error) { return r.queryDate(ctx, filter) } diff --git a/test/schema.graphql b/test/schema.graphql index f5062713131..036abdba148 100644 --- a/test/schema.graphql +++ b/test/schema.graphql @@ -18,7 +18,15 @@ input DateFilter { op: DATE_FILTER_OP = EQ } +type User { + name: String +} +type Viewer { + user: User +} + type Query { path: [Element] date(filter: DateFilter!): Boolean! + viewer: Viewer } diff --git a/test/types.json b/test/types.json index a21ec7a3287..4679c4cd6c4 100644 --- a/test/types.json +++ b/test/types.json @@ -1,3 +1,5 @@ { - "Element": "github.com/vektah/gqlgen/test.Element" + "Element": "github.com/vektah/gqlgen/test.Element", + "Viewer": "github.com/vektah/gqlgen/test.Viewer", + "User": "remote_api.User" } diff --git a/test/vendor/remote_api/user.go b/test/vendor/remote_api/user.go new file mode 100644 index 00000000000..d90cbde895b --- /dev/null +++ b/test/vendor/remote_api/user.go @@ -0,0 +1,5 @@ +package remote_api + +type User struct { + Name string +} diff --git a/test/viewer.go b/test/viewer.go new file mode 100644 index 00000000000..44786f0bf49 --- /dev/null +++ b/test/viewer.go @@ -0,0 +1,7 @@ +package test + +import "remote_api" + +type Viewer struct { + User *remote_api.User +}