From b07736ef8cfd03ba2a649c70cf8bfa3667102ecc Mon Sep 17 00:00:00 2001 From: Adam Scarr Date: Thu, 23 Aug 2018 15:27:07 +1000 Subject: [PATCH] Validate gopath when running gqlgen --- cmd/root.go | 11 ++++++ codegen/config.go | 16 ++------- codegen/config_test.go | 26 -------------- codegen/import_build.go | 10 ++---- internal/gopath/gopath.go | 37 ++++++++++++++++++++ internal/gopath/gopath_test.go | 62 ++++++++++++++++++++++++++++++++++ 6 files changed, 116 insertions(+), 46 deletions(-) create mode 100644 internal/gopath/gopath.go create mode 100644 internal/gopath/gopath_test.go diff --git a/cmd/root.go b/cmd/root.go index ddd02f18b4..8598acd315 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -6,6 +6,7 @@ import ( "log" "os" + "github.com/99designs/gqlgen/internal/gopath" "github.com/spf13/cobra" ) @@ -39,6 +40,16 @@ var rootCmd = &cobra.Command{ Long: `This is a library for quickly creating strictly typed graphql servers in golang. See https://gqlgen.com/ for a getting started guide.`, PersistentPreRun: func(cmd *cobra.Command, args []string) { + pwd, err := os.Getwd() + if err != nil { + fmt.Fprintf(os.Stderr, "unable to determine current workding dir: %s\n", err.Error()) + os.Exit(1) + } + + if !gopath.Contains(pwd) { + fmt.Fprintf(os.Stderr, "gqlgen must be run from inside your $GOPATH\n") + os.Exit(1) + } if verbose { log.SetFlags(0) } else { diff --git a/codegen/config.go b/codegen/config.go index 0ba1f65dab..8e2a5e74b3 100644 --- a/codegen/config.go +++ b/codegen/config.go @@ -8,6 +8,7 @@ import ( "path/filepath" "strings" + "github.com/99designs/gqlgen/internal/gopath" "github.com/pkg/errors" "github.com/vektah/gqlparser/ast" "gopkg.in/yaml.v2" @@ -107,22 +108,11 @@ func (c *PackageConfig) normalize() error { } func (c *PackageConfig) ImportPath() string { - dir := filepath.ToSlash(c.Dir()) - for _, gopath := range filepath.SplitList(build.Default.GOPATH) { - gopath = filepath.ToSlash(gopath) + "/src/" - if len(gopath) > len(dir) { - continue - } - if strings.EqualFold(gopath, dir[0:len(gopath)]) { - dir = dir[len(gopath):] - break - } - } - return dir + return gopath.MustDir2Import(c.Dir()) } func (c *PackageConfig) Dir() string { - return filepath.ToSlash(filepath.Dir(c.Filename)) + return filepath.Dir(c.Filename) } func (c *PackageConfig) Check() error { diff --git a/codegen/config_test.go b/codegen/config_test.go index f4260d6dfb..f990decf44 100644 --- a/codegen/config_test.go +++ b/codegen/config_test.go @@ -1,12 +1,10 @@ package codegen import ( - "go/build" "os" "path/filepath" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -58,27 +56,3 @@ func TestLoadDefaultConfig(t *testing.T) { require.True(t, os.IsNotExist(err)) }) } - -func Test_fullPackageName(t *testing.T) { - origBuildContext := build.Default - defer func() { build.Default = origBuildContext }() - - t.Run("gopath longer than package name", func(t *testing.T) { - p := PackageConfig{Filename: "/b/src/y/foo/bar/baz.go"} - build.Default.GOPATH = "/a/src/xxxxxxxxxxxxxxxxxxxxxxxx:/b/src/y" - var got string - ok := assert.NotPanics(t, func() { got = p.ImportPath() }) - if ok { - assert.Equal(t, "/b/src/y/foo/bar", got) - } - }) - t.Run("stop searching on first hit", func(t *testing.T) { - p := PackageConfig{Filename: "/a/src/x/foo/bar/baz.go"} - build.Default.GOPATH = "/a/src/x:/b/src/y" - var got string - ok := assert.NotPanics(t, func() { got = p.ImportPath() }) - if ok { - assert.Equal(t, "/a/src/x/foo/bar", got) - } - }) -} diff --git a/codegen/import_build.go b/codegen/import_build.go index d7d9677445..d8f2a2def2 100644 --- a/codegen/import_build.go +++ b/codegen/import_build.go @@ -5,12 +5,11 @@ import ( "go/build" "sort" "strconv" - "strings" - // Import and ignore the ambient imports listed below so dependency managers // don't prune unused code for us. Both lists should be kept in sync. _ "github.com/99designs/gqlgen/graphql" _ "github.com/99designs/gqlgen/graphql/introspection" + "github.com/99designs/gqlgen/internal/gopath" _ "github.com/vektah/gqlparser" _ "github.com/vektah/gqlparser/ast" ) @@ -55,7 +54,8 @@ func (s *Imports) add(path string) *Import { return nil } - if stringHasSuffixFold(s.destDir, path) { + // if we are referencing our own package we dont need an import + if gopath.MustDir2Import(s.destDir) == path { return nil } @@ -77,10 +77,6 @@ func (s *Imports) add(path string) *Import { return imp } -func stringHasSuffixFold(s, suffix string) bool { - return len(s) >= len(suffix) && strings.EqualFold(s[len(s)-len(suffix):], suffix) -} - func (s Imports) finalize() []*Import { // ensure stable ordering by sorting sort.Slice(s.imports, func(i, j int) bool { diff --git a/internal/gopath/gopath.go b/internal/gopath/gopath.go new file mode 100644 index 0000000000..c9b66167a3 --- /dev/null +++ b/internal/gopath/gopath.go @@ -0,0 +1,37 @@ +package gopath + +import ( + "fmt" + "go/build" + "path/filepath" + "strings" +) + +var NotFound = fmt.Errorf("not on GOPATH") + +// Contains returns true if the given directory is in the GOPATH +func Contains(dir string) bool { + _, err := Dir2Import(dir) + return err == nil +} + +// Dir2Import takes an *absolute* path and returns a golang import path for the package, and returns an error if it isn't on the gopath +func Dir2Import(dir string) (string, error) { + dir = filepath.ToSlash(dir) + for _, gopath := range filepath.SplitList(build.Default.GOPATH) { + gopath = filepath.ToSlash(filepath.Join(gopath, "src")) + if len(gopath) < len(dir) && strings.EqualFold(gopath, dir[0:len(gopath)]) { + return dir[len(gopath)+1:], nil + } + } + return "", NotFound +} + +// MustDir2Import takes an *absolute* path and returns a golang import path for the package, and panics if it isn't on the gopath +func MustDir2Import(dir string) string { + pkg, err := Dir2Import(dir) + if err != nil { + panic(err) + } + return pkg +} diff --git a/internal/gopath/gopath_test.go b/internal/gopath/gopath_test.go new file mode 100644 index 0000000000..847ad1e856 --- /dev/null +++ b/internal/gopath/gopath_test.go @@ -0,0 +1,62 @@ +package gopath + +import ( + "go/build" + "runtime" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestContains(t *testing.T) { + origBuildContext := build.Default + defer func() { build.Default = origBuildContext }() + + if runtime.GOOS == "windows" { + build.Default.GOPATH = `C:\go;C:\Users\user\go` + + assert.True(t, Contains(`C:\go\src\github.com\vektah\gqlgen`)) + assert.True(t, Contains(`C:\go\src\fpp`)) + assert.True(t, Contains(`C:/go/src/github.com/vektah/gqlgen`)) + assert.True(t, Contains(`C:\Users\user\go\src\foo`)) + assert.False(t, Contains(`C:\tmp`)) + assert.False(t, Contains(`C:\Users\user`)) + assert.False(t, Contains(`C:\Users\another\go`)) + } else { + build.Default.GOPATH = "/go:/home/user/go" + + assert.True(t, Contains("/go/src/github.com/vektah/gqlgen")) + assert.True(t, Contains("/go/src/foo")) + assert.True(t, Contains("/home/user/go/src/foo")) + assert.False(t, Contains("/tmp")) + assert.False(t, Contains("/home/user")) + assert.False(t, Contains("/home/another/go")) + } +} + +func TestDir2Package(t *testing.T) { + origBuildContext := build.Default + defer func() { build.Default = origBuildContext }() + + if runtime.GOOS == "windows" { + build.Default.GOPATH = "C:/xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx;C:/a/y;C:/b/" + + assert.Equal(t, "foo/bar", MustDir2Import("C:/a/y/src/foo/bar")) + assert.Equal(t, "foo/bar", MustDir2Import(`C:\a\y\src\foo\bar`)) + assert.Equal(t, "foo/bar", MustDir2Import("C:/b/src/foo/bar")) + assert.Equal(t, "foo/bar", MustDir2Import(`C:\b\src\foo\bar`)) + + assert.PanicsWithValue(t, NotFound, func() { + MustDir2Import("C:/tmp/foo") + }) + } else { + build.Default.GOPATH = "/xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx:/a/y:/b/" + + assert.Equal(t, "foo/bar", MustDir2Import("/a/y/src/foo/bar")) + assert.Equal(t, "foo/bar", MustDir2Import("/b/src/foo/bar")) + + assert.PanicsWithValue(t, NotFound, func() { + MustDir2Import("/tmp/foo") + }) + } +}