Skip to content

Commit

Permalink
Copy existing resolver bodies when regenerating new resolvers
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Feb 3, 2020
1 parent 9e3b399 commit 6ec3650
Show file tree
Hide file tree
Showing 10 changed files with 217 additions and 85 deletions.
14 changes: 7 additions & 7 deletions codegen/templates/templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ func center(width int, pad string, s string) string {

func Funcs() template.FuncMap {
return template.FuncMap{
"ucFirst": ucFirst,
"lcFirst": lcFirst,
"ucFirst": UcFirst,
"lcFirst": LcFirst,
"quote": strconv.Quote,
"rawQuote": rawQuote,
"dump": Dump,
Expand All @@ -185,7 +185,7 @@ func Funcs() template.FuncMap {
}
}

func ucFirst(s string) string {
func UcFirst(s string) string {
if s == "" {
return ""
}
Expand All @@ -194,7 +194,7 @@ func ucFirst(s string) string {
return string(r)
}

func lcFirst(s string) string {
func LcFirst(s string) string {
if s == "" {
return ""
}
Expand Down Expand Up @@ -275,7 +275,7 @@ func ToGo(name string) string {
if strings.ToUpper(word) == word || strings.ToLower(word) == word {
// FOO or foo → Foo
// FOo → FOo
word = ucFirst(strings.ToLower(word))
word = UcFirst(strings.ToLower(word))
}
}
runes = append(runes, []rune(word)...)
Expand All @@ -297,13 +297,13 @@ func ToGoPrivate(name string) string {
word = strings.ToLower(info.Word)
} else {
// ITicket → iTicket
word = lcFirst(info.Word)
word = LcFirst(info.Word)
}
first = false
case info.MatchCommonInitial:
word = strings.ToUpper(word)
case !info.HasCommonInitial:
word = ucFirst(strings.ToLower(word))
word = UcFirst(strings.ToLower(word))
}
runes = append(runes, []rune(word)...)
})
Expand Down
3 changes: 2 additions & 1 deletion example/config/.gqlgen.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ exec:
model:
filename: models_gen.go
resolver:
filename: resolver.go
type: Resolver
layout: follow-schema
dir: .

models:
Todo: # Object
Expand Down
53 changes: 0 additions & 53 deletions example/config/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@

package config

import (
"context"
"fmt"
)

func New() Config {
c := Config{
Resolvers: &Resolver{
Expand All @@ -25,51 +20,3 @@ type Resolver struct {
todos []*Todo
nextID int
}

func (r *Resolver) Mutation() MutationResolver {
return &mutationResolver{r}
}
func (r *Resolver) Query() QueryResolver {
return &queryResolver{r}
}
func (r *Resolver) Todo() TodoResolver {
return &todoResolver{r}
}

type mutationResolver struct{ *Resolver }

func (r *mutationResolver) CreateTodo(ctx context.Context, input NewTodo) (*Todo, error) {
newID := r.nextID
r.nextID++

newTodo := &Todo{
DatabaseID: newID,
Description: input.Text,
}

r.todos = append(r.todos, newTodo)

return newTodo, nil
}

type queryResolver struct{ *Resolver }

func (r *queryResolver) Todos(ctx context.Context) ([]*Todo, error) {
return r.todos, nil
}

type todoResolver struct{ *Resolver }

func (r *todoResolver) Description(ctx context.Context, obj *Todo) (string, error) {
panic("implement me")
}

func (r *todoResolver) ID(ctx context.Context, obj *Todo) (string, error) {
if obj.ID != "" {
return obj.ID, nil
}

obj.ID = fmt.Sprintf("TODO:%d", obj.DatabaseID)

return obj.ID, nil
}
30 changes: 30 additions & 0 deletions example/config/schema_resolvers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// This file will be automatically regenerated based on the schema, any resolver implementations
// will be copied through when generating and any unknown code will be moved to the end.
package config

import (
"context"
)

func (r *mutationResolver) CreateTodo(ctx context.Context, input NewTodo) (*Todo, error) {
newID := r.nextID
r.nextID++

newTodo := &Todo{
DatabaseID: newID,
Description: input.Text,
}

r.todos = append(r.todos, newTodo)

return newTodo, nil
}
func (r *queryResolver) Todos(ctx context.Context) ([]*Todo, error) {
return r.todos, nil
}

func (r *Resolver) Mutation() MutationResolver { return &mutationResolver{r} }
func (r *Resolver) Query() QueryResolver { return &queryResolver{r} }

type mutationResolver struct{ *Resolver }
type queryResolver struct{ *Resolver }
22 changes: 22 additions & 0 deletions example/config/todo_resolvers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// This file will be automatically regenerated based on the schema, any resolver implementations
// will be copied through when generating and any unknown code will be moved to the end.
package config

import (
"context"
"fmt"
)

func (r *todoResolver) ID(ctx context.Context, obj *Todo) (string, error) {
if obj.ID != "" {
return obj.ID, nil
}

obj.ID = fmt.Sprintf("TODO:%d", obj.DatabaseID)

return obj.ID, nil
}

func (r *Resolver) Todo() TodoResolver { return &todoResolver{r} }

type todoResolver struct{ *Resolver }
87 changes: 87 additions & 0 deletions internal/rewrite/rewriter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package rewrite

import (
"fmt"
"go/ast"
"go/token"
"io/ioutil"

"golang.org/x/tools/go/packages"
)

type Rewriter struct {
pkg *packages.Package
files map[string]string
}

func New(importPath string) (*Rewriter, error) {
pkgs, err := packages.Load(&packages.Config{
Mode: packages.NeedSyntax | packages.NeedTypes,
}, importPath)
if err != nil {
return nil, err
}

return &Rewriter{
pkg: pkgs[0],
files: map[string]string{},
}, nil
}

func (r *Rewriter) getSource(start, end token.Pos) string {
startPos := r.pkg.Fset.Position(start)
endPos := r.pkg.Fset.Position(end)

if startPos.Filename != endPos.Filename {
panic("cant get source spanning multiple files")
}

file := r.getFile(startPos.Filename)
return file[startPos.Offset:endPos.Offset]
}

func (r *Rewriter) getFile(filename string) string {
if _, ok := r.files[filename]; !ok {
b, err := ioutil.ReadFile(filename)
if err != nil {
panic(fmt.Errorf("unable to load file, already exists: %s", err.Error()))
}

r.files[filename] = string(b)

}

return r.files[filename]
}

func (r *Rewriter) GetMethodBody(structname string, methodname string) string {
for _, f := range r.pkg.Syntax {
for _, d := range f.Decls {
switch d := d.(type) {
case *ast.FuncDecl:
if d.Name.Name != methodname {
continue
}
if d.Recv.List == nil {
continue
}
recv := d.Recv.List[0].Type
if star, isStar := d.Recv.List[0].Type.(*ast.StarExpr); isStar {
recv = star.X
}
ident, ok := recv.(*ast.Ident)
if !ok {
continue
}

if ident.Name != structname {
continue
}

return r.getSource(d.Body.Pos()+1, d.Body.End()-1)
}
}
}

return ""
}
22 changes: 22 additions & 0 deletions internal/rewrite/rewriter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package rewrite

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestRewriter(t *testing.T) {
r, err := New("github.com/99designs/gqlgen/internal/rewrite/testdata")
require.NoError(t, err)

body := r.GetMethodBody("Foo", "Method")
require.Equal(t, `
// leading comment
// field comment
m.Field++
// trailing comment
`, body)
}
14 changes: 14 additions & 0 deletions internal/rewrite/testdata/example.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package testdata

type Foo struct {
Field int
}

func (m *Foo) Method(arg int) {
// leading comment

// field comment
m.Field++

// trailing comment
}
Loading

0 comments on commit 6ec3650

Please sign in to comment.