diff --git a/cmd/wire/main.go b/cmd/wire/main.go index e70a5895..ab9a193c 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -18,14 +18,12 @@ package main import ( + "context" "errors" "fmt" - "go/build" "go/token" "go/types" - "io/ioutil" "os" - "path/filepath" "reflect" "sort" "strconv" @@ -74,25 +72,17 @@ func generate(pkg string) error { if err != nil { return err } - pkgInfo, err := build.Default.Import(pkg, wd, build.FindOnly) - if err != nil { - return err - } - out, errs := wire.Generate(&build.Default, wd, pkg) + out, errs := wire.Generate(context.Background(), wd, os.Environ(), pkg) if len(errs) > 0 { logErrors(errs) return errors.New("generate failed") } - if len(out) == 0 { + if len(out.Content) == 0 { // No Wire directives, don't write anything. fmt.Fprintln(os.Stderr, "wire: no injector found for", pkg) return nil } - p := filepath.Join(pkgInfo.Dir, "wire_gen.go") - if err := ioutil.WriteFile(p, out, 0666); err != nil { - return err - } - return nil + return out.Commit() } // show runs the show subcommand. @@ -106,7 +96,7 @@ func show(pkgs ...string) error { if err != nil { return err } - info, errs := wire.Load(&build.Default, wd, pkgs) + info, errs := wire.Load(context.Background(), wd, os.Environ(), pkgs) if info != nil { keys := make([]wire.ProviderSetID, 0, len(info.Sets)) for k := range info.Sets { @@ -185,7 +175,7 @@ func check(pkgs ...string) error { if err != nil { return err } - _, errs := wire.Load(&build.Default, wd, pkgs) + _, errs := wire.Load(context.Background(), wd, os.Environ(), pkgs) if len(errs) > 0 { logErrors(errs) return errors.New("error loading packages") diff --git a/internal/wire/parse.go b/internal/wire/parse.go index b701b40f..3eeb01d8 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -15,6 +15,7 @@ package wire import ( + "context" "errors" "fmt" "go/ast" @@ -190,11 +191,19 @@ type Value struct { info *types.Info } -// Load finds all the provider sets in the given packages, as well as -// the provider sets' transitive dependencies. It may return both errors -// and Info. -func Load(bctx *build.Context, wd string, pkgs []string) (*Info, []error) { - prog, errs := load(bctx, wd, pkgs) +// Load finds all the provider sets in the packages that match the given +// patterns, as well as the provider sets' transitive dependencies. It +// may return both errors and Info. The patterns are defined by the +// underlying build system. For the go tool, this is described at +// https://golang.org/cmd/go/#hdr-Package_lists_and_patterns +// +// wd is the working directory and env is the set of environment +// variables to use when loading the packages specified by patterns. If +// env is nil or empty, it is interpreted as an empty set of variables. +// In case of duplicate environment variables, the last one in the list +// takes precedence. +func Load(ctx context.Context, wd string, env []string, patterns []string) (*Info, []error) { + prog, errs := load(ctx, wd, env, patterns) if len(errs) > 0 { return nil, errs } @@ -275,12 +284,22 @@ func Load(bctx *build.Context, wd string, pkgs []string) (*Info, []error) { return info, ec.errors } -// load typechecks the packages, including function body type checking -// for the packages directly named. -func load(bctx *build.Context, wd string, pkgs []string) (*loader.Program, []error) { +// load typechecks the packages that match the given patterns, including +// function body type checking for the packages that directly match. The +// patterns are defined by the underlying build system. For the go tool, +// this is described at +// https://golang.org/cmd/go/#hdr-Package_lists_and_patterns +// +// wd is the working directory and env is the set of environment +// variables to use when loading the packages specified by patterns. If +// env is nil or empty, it is interpreted as an empty set of variables. +// In case of duplicate environment variables, the last one in the list +// takes precedence. +func load(ctx context.Context, wd string, env []string, patterns []string) (*loader.Program, []error) { + bctx := buildContextFromEnv(env) var foundPkgs []*build.Package ec := new(errorCollector) - for _, name := range pkgs { + for _, name := range patterns { p, err := bctx.Import(name, wd, build.FindOnly) if err != nil { ec.add(err) @@ -320,7 +339,7 @@ func load(bctx *build.Context, wd string, pkgs []string) (*loader.Program, []err return pkg, err }, } - for _, name := range pkgs { + for _, name := range patterns { conf.Import(name) } @@ -334,6 +353,35 @@ func load(bctx *build.Context, wd string, pkgs []string) (*loader.Program, []err return prog, nil } +func buildContextFromEnv(env []string) *build.Context { + // TODO(#78): Remove this function in favor of using go/packages, + // which does not need a *build.Context. + + getenv := func(name string) string { + for i := len(env) - 1; i >= 0; i-- { + if strings.HasPrefix(env[i], name+"=") { + return env[i][len(name)+1:] + } + } + return "" + } + bctx := new(build.Context) + *bctx = build.Default + if v := getenv("GOARCH"); v != "" { + bctx.GOARCH = v + } + if v := getenv("GOOS"); v != "" { + bctx.GOOS = v + } + if v := getenv("GOROOT"); v != "" { + bctx.GOROOT = v + } + if v := getenv("GOPATH"); v != "" { + bctx.GOPATH = v + } + return bctx +} + func importPathInPkgList(pkgs []*build.Package, path string) bool { for _, p := range pkgs { if path == p.ImportPath { diff --git a/internal/wire/testdata/BuildTagsRelativePkg/pkg b/internal/wire/testdata/BuildTagsRelativePkg/pkg index 6b704cdf..b10f96cf 100644 --- a/internal/wire/testdata/BuildTagsRelativePkg/pkg +++ b/internal/wire/testdata/BuildTagsRelativePkg/pkg @@ -1 +1 @@ -./example.com/foo +./foo diff --git a/internal/wire/testdata/Cycle/want/wire_errs.txt b/internal/wire/testdata/Cycle/want/wire_errs.txt index 1eb88212..2d86ab93 100644 --- a/internal/wire/testdata/Cycle/want/wire_errs.txt +++ b/internal/wire/testdata/Cycle/want/wire_errs.txt @@ -1,4 +1,4 @@ -/wire_gopath/src/example.com/foo/wire.go:x:y: cycle for example.com/foo.Bar: +example.com/foo/wire.go:x:y: cycle for example.com/foo.Bar: example.com/foo.Bar (example.com/foo.provideBar) -> example.com/foo.Foo (example.com/foo.provideFoo) -> example.com/foo.Baz (example.com/foo.provideBaz) -> diff --git a/internal/wire/testdata/EmptyVar/want/wire_errs.txt b/internal/wire/testdata/EmptyVar/want/wire_errs.txt index c197b4e4..d7e1d21d 100644 --- a/internal/wire/testdata/EmptyVar/want/wire_errs.txt +++ b/internal/wire/testdata/EmptyVar/want/wire_errs.txt @@ -1 +1 @@ -/wire_gopath/src/example.com/foo/wire.go:x:y: var example.com/foo.myFakeSet struct{} is not a provider or a provider set \ No newline at end of file +example.com/foo/wire.go:x:y: var example.com/foo.myFakeSet struct{} is not a provider or a provider set \ No newline at end of file diff --git a/internal/wire/testdata/InjectInputConflict/want/wire_errs.txt b/internal/wire/testdata/InjectInputConflict/want/wire_errs.txt index 1bcae60e..f3b2c2c5 100644 --- a/internal/wire/testdata/InjectInputConflict/want/wire_errs.txt +++ b/internal/wire/testdata/InjectInputConflict/want/wire_errs.txt @@ -1 +1 @@ -/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectBar: input of example.com/foo.Foo conflicts with provider provideFoo at /wire_gopath/src/example.com/foo/foo.go:x:y \ No newline at end of file +example.com/foo/wire.go:x:y: inject injectBar: input of example.com/foo.Foo conflicts with provider provideFoo at example.com/foo/foo.go:x:y \ No newline at end of file diff --git a/internal/wire/testdata/InjectorMissingCleanup/want/wire_errs.txt b/internal/wire/testdata/InjectorMissingCleanup/want/wire_errs.txt index 8919b7d0..a72614a1 100644 --- a/internal/wire/testdata/InjectorMissingCleanup/want/wire_errs.txt +++ b/internal/wire/testdata/InjectorMissingCleanup/want/wire_errs.txt @@ -1 +1 @@ -/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectFoo: provider for example.com/foo.Foo returns cleanup but injection does not return cleanup function \ No newline at end of file +example.com/foo/wire.go:x:y: inject injectFoo: provider for example.com/foo.Foo returns cleanup but injection does not return cleanup function \ No newline at end of file diff --git a/internal/wire/testdata/InjectorMissingError/want/wire_errs.txt b/internal/wire/testdata/InjectorMissingError/want/wire_errs.txt index 6b3e5783..7e2aad10 100644 --- a/internal/wire/testdata/InjectorMissingError/want/wire_errs.txt +++ b/internal/wire/testdata/InjectorMissingError/want/wire_errs.txt @@ -1 +1 @@ -/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectFoo: provider for example.com/foo.Foo returns error but injection not allowed to fail \ No newline at end of file +example.com/foo/wire.go:x:y: inject injectFoo: provider for example.com/foo.Foo returns error but injection not allowed to fail \ No newline at end of file diff --git a/internal/wire/testdata/InterfaceBindingDoesntImplement/want/wire_errs.txt b/internal/wire/testdata/InterfaceBindingDoesntImplement/want/wire_errs.txt index 2d86d8bc..647d06fe 100644 --- a/internal/wire/testdata/InterfaceBindingDoesntImplement/want/wire_errs.txt +++ b/internal/wire/testdata/InterfaceBindingDoesntImplement/want/wire_errs.txt @@ -1 +1 @@ -/wire_gopath/src/example.com/foo/wire.go:x:y: string does not implement example.com/foo.Fooer \ No newline at end of file +example.com/foo/wire.go:x:y: string does not implement example.com/foo.Fooer \ No newline at end of file diff --git a/internal/wire/testdata/InterfaceBindingInvalidArg0/want/wire_errs.txt b/internal/wire/testdata/InterfaceBindingInvalidArg0/want/wire_errs.txt index 76e59543..4f3e78fb 100644 --- a/internal/wire/testdata/InterfaceBindingInvalidArg0/want/wire_errs.txt +++ b/internal/wire/testdata/InterfaceBindingInvalidArg0/want/wire_errs.txt @@ -1 +1 @@ -/wire_gopath/src/example.com/foo/wire.go:x:y: first argument to Bind must be a pointer to an interface type; found string \ No newline at end of file +example.com/foo/wire.go:x:y: first argument to Bind must be a pointer to an interface type; found string \ No newline at end of file diff --git a/internal/wire/testdata/InterfaceBindingNotEnoughArgs/want/wire_errs.txt b/internal/wire/testdata/InterfaceBindingNotEnoughArgs/want/wire_errs.txt index c75e2cbd..022422f0 100644 --- a/internal/wire/testdata/InterfaceBindingNotEnoughArgs/want/wire_errs.txt +++ b/internal/wire/testdata/InterfaceBindingNotEnoughArgs/want/wire_errs.txt @@ -1 +1 @@ -/wire_gopath/src/example.com/foo/wire.go:x:y: too few arguments in call to wire.Bind \ No newline at end of file +example.com/foo/wire.go:x:y: too few arguments in call to wire.Bind \ No newline at end of file diff --git a/internal/wire/testdata/InterfaceValueDoesntImplement/want/wire_errs.txt b/internal/wire/testdata/InterfaceValueDoesntImplement/want/wire_errs.txt index 525859ef..5357ddfc 100644 --- a/internal/wire/testdata/InterfaceValueDoesntImplement/want/wire_errs.txt +++ b/internal/wire/testdata/InterfaceValueDoesntImplement/want/wire_errs.txt @@ -1 +1 @@ -/wire_gopath/src/example.com/foo/wire.go:x:y: string does not implement io.Reader \ No newline at end of file +example.com/foo/wire.go:x:y: string does not implement io.Reader \ No newline at end of file diff --git a/internal/wire/testdata/InterfaceValueInvalidArg0/want/wire_errs.txt b/internal/wire/testdata/InterfaceValueInvalidArg0/want/wire_errs.txt index b956fb48..89cc2e7e 100644 --- a/internal/wire/testdata/InterfaceValueInvalidArg0/want/wire_errs.txt +++ b/internal/wire/testdata/InterfaceValueInvalidArg0/want/wire_errs.txt @@ -1 +1 @@ -/wire_gopath/src/example.com/foo/wire.go:x:y: first argument to InterfaceValue must be a pointer to an interface type; found string \ No newline at end of file +example.com/foo/wire.go:x:y: first argument to InterfaceValue must be a pointer to an interface type; found string \ No newline at end of file diff --git a/internal/wire/testdata/InterfaceValueNotEnoughArgs/want/wire_errs.txt b/internal/wire/testdata/InterfaceValueNotEnoughArgs/want/wire_errs.txt index 5c75536a..f203dbd3 100644 --- a/internal/wire/testdata/InterfaceValueNotEnoughArgs/want/wire_errs.txt +++ b/internal/wire/testdata/InterfaceValueNotEnoughArgs/want/wire_errs.txt @@ -1 +1 @@ -/wire_gopath/src/example.com/foo/wire.go:x:y: too few arguments in call to wire.InterfaceValue \ No newline at end of file +example.com/foo/wire.go:x:y: too few arguments in call to wire.InterfaceValue \ No newline at end of file diff --git a/internal/wire/testdata/MultipleBindings/want/wire_errs.txt b/internal/wire/testdata/MultipleBindings/want/wire_errs.txt index b9f66dd8..9379f9c2 100644 --- a/internal/wire/testdata/MultipleBindings/want/wire_errs.txt +++ b/internal/wire/testdata/MultipleBindings/want/wire_errs.txt @@ -1,41 +1,41 @@ -/wire_gopath/src/example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Foo +example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Foo current: -<- provider "provideFooAgain" (/wire_gopath/src/example.com/foo/foo.go:x:y) +<- provider "provideFooAgain" (example.com/foo/foo.go:x:y) previous: -<- provider "provideFoo" (/wire_gopath/src/example.com/foo/foo.go:x:y) +<- provider "provideFoo" (example.com/foo/foo.go:x:y) -/wire_gopath/src/example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Foo +example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Foo current: -<- provider "provideFoo" (/wire_gopath/src/example.com/foo/foo.go:x:y) +<- provider "provideFoo" (example.com/foo/foo.go:x:y) previous: -<- provider "provideFoo" (/wire_gopath/src/example.com/foo/foo.go:x:y) -<- provider set "Set" (/wire_gopath/src/example.com/foo/foo.go:x:y) +<- provider "provideFoo" (example.com/foo/foo.go:x:y) +<- provider set "Set" (example.com/foo/foo.go:x:y) -/wire_gopath/src/example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Foo +example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Foo current: -<- provider "provideFoo" (/wire_gopath/src/example.com/foo/foo.go:x:y) +<- provider "provideFoo" (example.com/foo/foo.go:x:y) previous: -<- provider "provideFoo" (/wire_gopath/src/example.com/foo/foo.go:x:y) -<- provider set "Set" (/wire_gopath/src/example.com/foo/foo.go:x:y) -<- provider set "SuperSet" (/wire_gopath/src/example.com/foo/foo.go:x:y) +<- provider "provideFoo" (example.com/foo/foo.go:x:y) +<- provider set "Set" (example.com/foo/foo.go:x:y) +<- provider set "SuperSet" (example.com/foo/foo.go:x:y) -/wire_gopath/src/example.com/foo/foo.go:x:y: SetWithDuplicateBindings has multiple bindings for example.com/foo.Foo +example.com/foo/foo.go:x:y: SetWithDuplicateBindings has multiple bindings for example.com/foo.Foo current: -<- provider "provideFoo" (/wire_gopath/src/example.com/foo/foo.go:x:y) -<- provider set "Set" (/wire_gopath/src/example.com/foo/foo.go:x:y) -<- provider set "SuperSet" (/wire_gopath/src/example.com/foo/foo.go:x:y) +<- provider "provideFoo" (example.com/foo/foo.go:x:y) +<- provider set "Set" (example.com/foo/foo.go:x:y) +<- provider set "SuperSet" (example.com/foo/foo.go:x:y) previous: -<- provider "provideFoo" (/wire_gopath/src/example.com/foo/foo.go:x:y) -<- provider set "Set" (/wire_gopath/src/example.com/foo/foo.go:x:y) +<- provider "provideFoo" (example.com/foo/foo.go:x:y) +<- provider set "Set" (example.com/foo/foo.go:x:y) -/wire_gopath/src/example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Foo +example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Foo current: -<- wire.Value (/wire_gopath/src/example.com/foo/wire.go:x:y) +<- wire.Value (example.com/foo/wire.go:x:y) previous: -<- provider "provideFoo" (/wire_gopath/src/example.com/foo/foo.go:x:y) +<- provider "provideFoo" (example.com/foo/foo.go:x:y) -/wire_gopath/src/example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Bar +example.com/foo/wire.go:x:y: wire.Build has multiple bindings for example.com/foo.Bar current: -<- wire.Bind (/wire_gopath/src/example.com/foo/wire.go:x:y) +<- wire.Bind (example.com/foo/wire.go:x:y) previous: -<- provider "provideBar" (/wire_gopath/src/example.com/foo/foo.go:x:y) \ No newline at end of file +<- provider "provideBar" (example.com/foo/foo.go:x:y) \ No newline at end of file diff --git a/internal/wire/testdata/MultipleMissingInputs/want/wire_errs.txt b/internal/wire/testdata/MultipleMissingInputs/want/wire_errs.txt index 07800728..f46d55c6 100644 --- a/internal/wire/testdata/MultipleMissingInputs/want/wire_errs.txt +++ b/internal/wire/testdata/MultipleMissingInputs/want/wire_errs.txt @@ -1,12 +1,12 @@ -/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectMissingOutputType: no provider found for example.com/foo.Foo, output of injector +example.com/foo/wire.go:x:y: inject injectMissingOutputType: no provider found for example.com/foo.Foo, output of injector -/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectMultipleMissingTypes: no provider found for example.com/foo.Foo -needed by example.com/foo.Baz in provider "provideBaz" (/wire_gopath/src/example.com/foo/foo.go:x:y) +example.com/foo/wire.go:x:y: inject injectMultipleMissingTypes: no provider found for example.com/foo.Foo +needed by example.com/foo.Baz in provider "provideBaz" (example.com/foo/foo.go:x:y) -/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectMultipleMissingTypes: no provider found for example.com/foo.Bar -needed by example.com/foo.Baz in provider "provideBaz" (/wire_gopath/src/example.com/foo/foo.go:x:y) +example.com/foo/wire.go:x:y: inject injectMultipleMissingTypes: no provider found for example.com/foo.Bar +needed by example.com/foo.Baz in provider "provideBaz" (example.com/foo/foo.go:x:y) -/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectMissingRecursiveType: no provider found for example.com/foo.Foo -needed by example.com/foo.Zip in provider "provideZip" (/wire_gopath/src/example.com/foo/foo.go:x:y) -needed by example.com/foo.Zap in provider "provideZap" (/wire_gopath/src/example.com/foo/foo.go:x:y) -needed by example.com/foo.Zop in provider "provideZop" (/wire_gopath/src/example.com/foo/foo.go:x:y) \ No newline at end of file +example.com/foo/wire.go:x:y: inject injectMissingRecursiveType: no provider found for example.com/foo.Foo +needed by example.com/foo.Zip in provider "provideZip" (example.com/foo/foo.go:x:y) +needed by example.com/foo.Zap in provider "provideZap" (example.com/foo/foo.go:x:y) +needed by example.com/foo.Zop in provider "provideZop" (example.com/foo/foo.go:x:y) \ No newline at end of file diff --git a/internal/wire/testdata/NoImplicitInterface/want/wire_errs.txt b/internal/wire/testdata/NoImplicitInterface/want/wire_errs.txt index 2a5a901f..23b30e01 100644 --- a/internal/wire/testdata/NoImplicitInterface/want/wire_errs.txt +++ b/internal/wire/testdata/NoImplicitInterface/want/wire_errs.txt @@ -1 +1 @@ -/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectFooer: no provider found for example.com/foo.Fooer, output of injector \ No newline at end of file +example.com/foo/wire.go:x:y: inject injectFooer: no provider found for example.com/foo.Fooer, output of injector \ No newline at end of file diff --git a/internal/wire/testdata/UnexportedStruct/want/wire_errs.txt b/internal/wire/testdata/UnexportedStruct/want/wire_errs.txt index 36f9cc73..ae23f88c 100644 --- a/internal/wire/testdata/UnexportedStruct/want/wire_errs.txt +++ b/internal/wire/testdata/UnexportedStruct/want/wire_errs.txt @@ -1 +1 @@ -/wire_gopath/src/example.com/foo/wire.go:x:y: foo not exported by package bar \ No newline at end of file +example.com/foo/wire.go:x:y: foo not exported by package bar \ No newline at end of file diff --git a/internal/wire/testdata/UnexportedValue/want/wire_errs.txt b/internal/wire/testdata/UnexportedValue/want/wire_errs.txt index 1ad98025..07b5c342 100644 --- a/internal/wire/testdata/UnexportedValue/want/wire_errs.txt +++ b/internal/wire/testdata/UnexportedValue/want/wire_errs.txt @@ -1 +1 @@ -/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectedMessage: value string can't be used: uses unexported identifier privateMsg \ No newline at end of file +example.com/foo/wire.go:x:y: inject injectedMessage: value string can't be used: uses unexported identifier privateMsg \ No newline at end of file diff --git a/internal/wire/testdata/UnusedProviders/want/wire_errs.txt b/internal/wire/testdata/UnusedProviders/want/wire_errs.txt index ab32986b..3bfaac1e 100644 --- a/internal/wire/testdata/UnusedProviders/want/wire_errs.txt +++ b/internal/wire/testdata/UnusedProviders/want/wire_errs.txt @@ -1,7 +1,7 @@ -/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectBar: unused provider set "unusedSet" +example.com/foo/wire.go:x:y: inject injectBar: unused provider set "unusedSet" -/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectBar: unused provider "provideUnused" +example.com/foo/wire.go:x:y: inject injectBar: unused provider "provideUnused" -/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectBar: unused value of type string +example.com/foo/wire.go:x:y: inject injectBar: unused value of type string -/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectBar: unused interface binding to type example.com/foo.Fooer \ No newline at end of file +example.com/foo/wire.go:x:y: inject injectBar: unused interface binding to type example.com/foo.Fooer \ No newline at end of file diff --git a/internal/wire/testdata/ValueFromFunctionScope/want/wire_errs.txt b/internal/wire/testdata/ValueFromFunctionScope/want/wire_errs.txt index a7cbd58a..ed7a45dc 100644 --- a/internal/wire/testdata/ValueFromFunctionScope/want/wire_errs.txt +++ b/internal/wire/testdata/ValueFromFunctionScope/want/wire_errs.txt @@ -1 +1 @@ -/wire_gopath/src/example.com/foo/wire.go:x:y: inject injectBar: value int can't be used: f is not declared in package scope \ No newline at end of file +example.com/foo/wire.go:x:y: inject injectBar: value int can't be used: f is not declared in package scope \ No newline at end of file diff --git a/internal/wire/testdata/ValueIsInterfaceValue/want/wire_errs.txt b/internal/wire/testdata/ValueIsInterfaceValue/want/wire_errs.txt index 939a9726..19af8703 100644 --- a/internal/wire/testdata/ValueIsInterfaceValue/want/wire_errs.txt +++ b/internal/wire/testdata/ValueIsInterfaceValue/want/wire_errs.txt @@ -1 +1 @@ -/wire_gopath/src/example.com/foo/wire.go:x:y: argument to Value may not be an interface value (found io.Reader); use InterfaceValue instead \ No newline at end of file +example.com/foo/wire.go:x:y: argument to Value may not be an interface value (found io.Reader); use InterfaceValue instead \ No newline at end of file diff --git a/internal/wire/wire.go b/internal/wire/wire.go index a44da210..ab6983f0 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -18,13 +18,15 @@ package wire import ( "bytes" + "context" + "errors" "fmt" "go/ast" - "go/build" "go/format" "go/printer" "go/token" "go/types" + "io/ioutil" "path/filepath" "sort" "strconv" @@ -36,22 +38,50 @@ import ( "golang.org/x/tools/go/loader" ) +// GeneratedFile stores the content of a call to Generate and the +// desired on-disk location of the file. +type GeneratedFile struct { + Path string + Content []byte +} + +// Commit writes the generated file to disk. +func (gen GeneratedFile) Commit() error { + if len(gen.Content) == 0 { + return nil + } + return ioutil.WriteFile(gen.Path, gen.Content, 0666) +} + // Generate performs dependency injection for a single package, -// returning the gofmt'd Go source code. -func Generate(bctx *build.Context, wd string, pkg string) ([]byte, []error) { - prog, errs := load(bctx, wd, []string{pkg}) +// returning the gofmt'd Go source code. The package pattern is defined +// by the underlying build system. For the go tool, this is described at +// https://golang.org/cmd/go/#hdr-Package_lists_and_patterns +// +// wd is the working directory and env is the set of environment +// variables to use when loading the package specified by pkgPattern. If +// env is nil or empty, it is interpreted as an empty set of variables. +// In case of duplicate environment variables, the last one in the list +// takes precedence. +func Generate(ctx context.Context, wd string, env []string, pkgPattern string) (GeneratedFile, []error) { + prog, errs := load(ctx, wd, env, []string{pkgPattern}) if len(errs) > 0 { - return nil, errs + return GeneratedFile{}, errs } if len(prog.InitialPackages()) != 1 { // This is more of a violated precondition than anything else. - return nil, []error{fmt.Errorf("load: got %d packages", len(prog.InitialPackages()))} + return GeneratedFile{}, []error{fmt.Errorf("load: got %d packages", len(prog.InitialPackages()))} } pkgInfo := prog.InitialPackages()[0] + outDir, err := detectOutputDir(prog.Fset, pkgInfo.Files) + if err != nil { + return GeneratedFile{}, []error{fmt.Errorf("load: %v", err)} + } + outFname := filepath.Join(outDir, "wire_gen.go") g := newGen(prog, pkgInfo.Pkg.Path()) injectorFiles, errs := generateInjectors(g, pkgInfo) if len(errs) > 0 { - return nil, errs + return GeneratedFile{}, errs } copyNonInjectorDecls(g, injectorFiles, &pkgInfo.Info) goSrc := g.frame() @@ -59,9 +89,22 @@ func Generate(bctx *build.Context, wd string, pkg string) ([]byte, []error) { if err != nil { // This is likely a bug from a poorly generated source file. // Return an error and the unformatted source. - return goSrc, []error{err} + return GeneratedFile{Path: outFname, Content: goSrc}, []error{err} + } + return GeneratedFile{Path: outFname, Content: fmtSrc}, nil +} + +func detectOutputDir(fset *token.FileSet, files []*ast.File) (string, error) { + if len(files) == 0 { + return "", errors.New("no files to derive output directory from") + } + dir := filepath.Dir(fset.File(files[0].Package).Name()) + for _, f := range files[1:] { + if dir2 := filepath.Dir(fset.File(f.Package).Name()); dir2 != dir { + return "", fmt.Errorf("found conflicting directories %q and %q", dir, dir2) + } } - return fmtSrc, nil + return dir, nil } // generateInjectors generates the injectors for a given package. diff --git a/internal/wire/wire_test.go b/internal/wire/wire_test.go index 2c1327c9..a8e3bf34 100644 --- a/internal/wire/wire_test.go +++ b/internal/wire/wire_test.go @@ -16,21 +16,16 @@ package wire import ( "bytes" - "errors" + "context" "fmt" "go/build" "go/types" - "io" "io/ioutil" "os" "os/exec" "path/filepath" - "regexp" - "runtime" - "sort" "strings" "testing" - "time" "unicode" "unicode/utf8" @@ -63,29 +58,39 @@ func TestWire(t *testing.T) { } tests = append(tests, test) } - wd := filepath.Join(magicGOPATH(), "src") + var goToolPath string if *setup.Record { - if _, err := os.Stat(filepath.Join(build.Default.GOROOT, "bin", "go")); err != nil { + goToolPath = filepath.Join(build.Default.GOROOT, "bin", "go") + if _, err := os.Stat(goToolPath); err != nil { t.Fatal("go toolchain not available:", err) } } + ctx := context.Background() for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { t.Parallel() - // Run Wire from a fake build context. - bctx := test.buildContext() - gen, errs := Generate(bctx, wd, test.pkg) - if len(gen) > 0 { - defer t.Logf("wire_gen.go:\n%s", gen) + // Materialize a temporary GOPATH directory. + gopath, err := ioutil.TempDir("", "wire_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(gopath) + if err := test.materialize(gopath); err != nil { + t.Fatal(err) + } + wd := filepath.Join(gopath, "src", "example.com") + gen, errs := Generate(ctx, wd, append(os.Environ(), "GOPATH="+gopath), test.pkg) + if len(gen.Content) > 0 { + defer t.Logf("wire_gen.go:\n%s", gen.Content) } if len(errs) > 0 { gotErrStrings := make([]string, len(errs)) for i, e := range errs { - gotErrStrings[i] = scrubError(e.Error()) - t.Log(gotErrStrings[i]) + t.Log(e.Error()) + gotErrStrings[i] = scrubError(gopath, e.Error()) } if !test.wantWireError { t.Fatal("Did not expect errors. To -record an error, create want/wire_errs.txt.") @@ -105,26 +110,37 @@ func TestWire(t *testing.T) { if test.wantWireError { t.Fatal("wire succeeded; want error") } + outPathSane := true + if prefix := gopath + string(os.PathSeparator) + "src" + string(os.PathSeparator); !strings.HasPrefix(gen.Path, prefix) { + outPathSane = false + t.Errorf("suggested output path = %q; want to start with %q", gen.Path, prefix) + } if *setup.Record { // Record ==> Build the generated Wire code, // check that the program's output matches the // expected output, save wire output on // success. - if err := goBuildCheck(test, wd, bctx, gen); err != nil { + if !outPathSane { + return + } + if err := gen.Commit(); err != nil { + t.Fatalf("failed to write wire_gen.go to test GOPATH: %v", err) + } + if err := goBuildCheck(goToolPath, gopath, test); err != nil { t.Fatalf("go build check failed: %v", err) } - wireGenFile := filepath.Join(testRoot, test.name, "want", "wire_gen.go") - if err := ioutil.WriteFile(wireGenFile, gen, 0666); err != nil { - t.Fatalf("failed to write wire_gen.go file: %v", err) + testdataWireGenPath := filepath.Join(testRoot, test.name, "want", "wire_gen.go") + if err := ioutil.WriteFile(testdataWireGenPath, gen.Content, 0666); err != nil { + t.Fatalf("failed to record wire_gen.go to testdata: %v", err) } } else { // Replay ==> Load golden file and compare to // generated result. This check is meant to // detect non-deterministic behavior in the // Generate function. - if !bytes.Equal(gen, test.wantWireOutput) { - gotS, wantS := string(gen), string(test.wantWireOutput) + if !bytes.Equal(gen.Content, test.wantWireOutput) { + gotS, wantS := string(gen.Content), string(test.wantWireOutput) diff := cmp.Diff(strings.Split(gotS, "\n"), strings.Split(wantS, "\n")) t.Fatalf("wire output differs from golden file. If this change is expected, run with -record to update the wire_gen.go file.\n*** got:\n%s\n\n*** want:\n%s\n\n*** diff:\n%s", gotS, wantS, diff) } @@ -133,49 +149,27 @@ func TestWire(t *testing.T) { } } -func goBuildCheck(test *testCase, wd string, bctx *build.Context, gen []byte) error { - // Find the absolute import path, since test.pkg may be a relative - // import path. - genPkg, err := bctx.Import(test.pkg, wd, build.FindOnly) - if err != nil { - return err - } - - // Run a `go build` with the generated output. - gopath, err := ioutil.TempDir("", "wire_test") - if err != nil { - return err - } - defer os.RemoveAll(gopath) - if err := test.materialize(gopath); err != nil { - return err - } - if len(gen) > 0 { - genPath := filepath.Join(gopath, "src", filepath.FromSlash(genPkg.ImportPath), "wire_gen.go") - if err := ioutil.WriteFile(genPath, gen, 0666); err != nil { - return err - } - } +func goBuildCheck(goToolPath, gopath string, test *testCase) error { + // Write go.mod files for example.com and the wire package. + // TODO(#78): Move this to happen in materialize() once modules work. if err := writeGoMod(gopath); err != nil { return err } + + // Run `go build`. testExePath := filepath.Join(gopath, "bin", "testprog") - realBuildCtx := &build.Context{ - GOARCH: bctx.GOARCH, - GOOS: bctx.GOOS, - GOROOT: bctx.GOROOT, - GOPATH: gopath, - CgoEnabled: bctx.CgoEnabled, - Compiler: bctx.Compiler, - BuildTags: bctx.BuildTags, - ReleaseTags: bctx.ReleaseTags, - } - buildDir := filepath.Join(gopath, "src", genPkg.ImportPath) buildCmd := []string{"build", "-o", testExePath} if test.name == "Vendor" && os.Getenv("GO111MODULE") == "on" { buildCmd = append(buildCmd, "-mod=vendor") } - if err := runGo(realBuildCtx, buildDir, buildCmd...); err != nil { + buildCmd = append(buildCmd, test.pkg) + cmd := exec.Command(goToolPath, buildCmd...) + cmd.Dir = filepath.Join(gopath, "src", "example.com") + cmd.Env = append(os.Environ(), "GOPATH="+gopath) + if buildOut, err := cmd.CombinedOutput(); err != nil { + if len(buildOut) > 0 { + return fmt.Errorf("build: %v; output:\n%s", err, buildOut) + } return fmt.Errorf("build: %v", err) } @@ -332,24 +326,15 @@ func TestDisambiguate(t *testing.T) { func isIdent(s string) bool { if len(s) == 0 { - if s == "foo" { - panic("BREAK3") - } return false } r, i := utf8.DecodeRuneInString(s) if !unicode.IsLetter(r) && r != '_' { - if s == "foo" { - panic("BREAK2") - } return false } for i < len(s) { r, sz := utf8.DecodeRuneInString(s[i:]) if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' { - if s == "foo" { - panic("BREAK1") - } return false } i += sz @@ -357,6 +342,80 @@ func isIdent(s string) bool { return true } +// scrubError rewrites the given string to remove occurrences of GOPATH/src, +// rewrites OS-specific path separators to slashes, and any line/column +// information to a fixed ":x:y". For example, if the gopath parameter is +// "C:\GOPATH" and running on Windows, the string +// "C:\GOPATH\src\foo\bar.go:15:4" would be rewritten to "foo/bar.go:x:y". +func scrubError(gopath string, s string) string { + sb := new(strings.Builder) + query := gopath + string(os.PathSeparator) + "src" + string(os.PathSeparator) + for { + // Find next occurrence of source root. This indicates the next path to + // scrub. + start := strings.Index(s, query) + if start == -1 { + sb.WriteString(s) + break + } + + // Find end of file name (extension ".go"). + fileStart := start + len(query) + fileEnd := strings.Index(s[fileStart:], ".go") + if fileEnd == -1 { + // If no ".go" occurs to end of string, further searches will fail too. + // Break the loop. + sb.WriteString(s) + break + } + fileEnd += fileStart + 3 // Advance to end of extension. + + // Write out file name and advance scrub position. + file := s[fileStart:fileEnd] + if os.PathSeparator != '/' { + file = strings.Replace(file, string(os.PathSeparator), "/", -1) + } + sb.WriteString(s[:start]) + sb.WriteString(file) + s = s[fileEnd:] + + // Peek past to see if there is line/column info. + linecol, linecolLen := scrubLineColumn(s) + sb.WriteString(linecol) + s = s[linecolLen:] + } + return sb.String() +} + +func scrubLineColumn(s string) (replacement string, n int) { + if !strings.HasPrefix(s, ":") { + return "", 0 + } + // Skip first colon and run of digits. + for n++; len(s) > n && '0' <= s[n] && s[n] <= '9'; { + n++ + } + if n == 1 { + // No digits followed colon. + return "", 0 + } + + // Start on column part. + if !strings.HasPrefix(s[n:], ":") { + return ":x", n + } + lineEnd := n + // Skip second colon and run of digits. + for n++; len(s) > n && '0' <= s[n] && s[n] <= '9'; { + n++ + } + if n == lineEnd+1 { + // No digits followed second colon. + return ":x", lineEnd + } + return ":x:y", n +} + type testCase struct { name string pkg string @@ -367,14 +426,6 @@ type testCase struct { wantWireErrorStrings []string } -var scrubLineNumberAndPositionRegex = regexp.MustCompile("\\.go:[\\d]+:[\\d]+") -var scrubLineNumberRegex = regexp.MustCompile("\\.go:[\\d]+") - -func scrubError(s string) string { - s = scrubLineNumberAndPositionRegex.ReplaceAllString(s, ".go:x:y") - return scrubLineNumberRegex.ReplaceAllString(s, ".go:x") -} - // loadTestCase reads a test case from a directory. // // The directory structure is: @@ -395,7 +446,7 @@ func scrubError(s string) string { // missing if no errors expected. // Distinct errors are separated by a blank line, // and line numbers and line positions are scrubbed -// (e.g., "foo.go:52:8" --> "foo.go:x:y"). +// (e.g. "$GOPATH/src/foo.go:52:8" --> "foo.go:x:y"). // // wire_gen.go // verified output of wire from a test run with @@ -417,7 +468,7 @@ func loadTestCase(root string, wireGoSrc []byte) (*testCase, error) { wantWireError := err == nil var wantWireErrorStrings []string if wantWireError { - wantWireErrorStrings = strings.Split(scrubError(string(wireErrb)), "\n\n") + wantWireErrorStrings = strings.Split(string(wireErrb), "\n\n") } else { if !*setup.Record { wantWireOutput, err = ioutil.ReadFile(filepath.Join(root, "want", "wire_gen.go")) @@ -448,7 +499,7 @@ func loadTestCase(root string, wireGoSrc []byte) (*testCase, error) { if err != nil { return err } - goFiles[filepath.Join("example.com", rel)] = data + goFiles["example.com/"+filepath.ToSlash(rel)] = data return nil }) if err != nil { @@ -465,187 +516,11 @@ func loadTestCase(root string, wireGoSrc []byte) (*testCase, error) { }, nil } -func (test *testCase) buildContext() *build.Context { - return &build.Context{ - GOARCH: build.Default.GOARCH, - GOOS: build.Default.GOOS, - GOROOT: build.Default.GOROOT, - GOPATH: magicGOPATH(), - CgoEnabled: build.Default.CgoEnabled, - Compiler: build.Default.Compiler, - ReleaseTags: build.Default.ReleaseTags, - HasSubdir: test.hasSubdir, - ReadDir: test.readDir, - OpenFile: test.openFile, - IsDir: test.isDir, - } -} - -const ( - magicGOPATHUnix = "/wire_gopath" - magicGOPATHWindows = `C:\wire_gopath` -) - -func magicGOPATH() string { - if runtime.GOOS == "windows" { - return magicGOPATHWindows - } - - return magicGOPATHUnix -} - -func (test *testCase) hasSubdir(root, dir string) (rel string, ok bool) { - // Don't consult filesystem, just lexical. - - if dir == root { - return "", true - } - prefix := root - if !strings.HasSuffix(prefix, string(filepath.Separator)) { - prefix += string(filepath.Separator) - } - if !strings.HasPrefix(dir, prefix) { - return "", false - } - return filepath.ToSlash(dir[len(prefix):]), true -} - -func (test *testCase) resolve(path string) (resolved string, pathType int) { - subpath, isMagic := test.hasSubdir(magicGOPATH(), path) - if !isMagic { - return path, systemPath - } - if subpath == "src" { - return "", gopathRoot - } - const srcPrefix = "src/" - if !strings.HasPrefix(subpath, srcPrefix) { - return subpath, gopathRoot - } - return subpath[len(srcPrefix):], gopathSrc -} - -// Path types -const ( - systemPath = iota - gopathRoot - gopathSrc -) - -func (test *testCase) readDir(dir string) ([]os.FileInfo, error) { - rpath, pathType := test.resolve(dir) - switch { - case pathType == systemPath: - return ioutil.ReadDir(rpath) - case pathType == gopathRoot && rpath == "": - return []os.FileInfo{dirInfo{name: "src"}}, nil - case pathType == gopathSrc: - names := make([]string, 0, len(test.goFiles)) - prefix := rpath + string(filepath.Separator) - for name := range test.goFiles { - if strings.HasPrefix(name, prefix) { - names = append(names, name[len(prefix):]) - } - } - sort.Strings(names) - ents := make([]os.FileInfo, 0, len(names)) - for _, name := range names { - if i := strings.IndexRune(name, filepath.Separator); i != -1 { - // Directory - dirName := name[:i] - if len(ents) == 0 || ents[len(ents)-1].Name() != dirName { - ents = append(ents, dirInfo{name: dirName}) - } - continue - } - ents = append(ents, fileInfo{ - name: name, - size: int64(len(test.goFiles[name])), - }) - } - return ents, nil - default: - return nil, &os.PathError{ - Op: "open", - Path: dir, - Err: os.ErrNotExist, - } - } -} - -func (test *testCase) isDir(path string) bool { - rpath, pathType := test.resolve(path) - switch { - case pathType == systemPath: - info, err := os.Stat(rpath) - return err == nil && info.IsDir() - case pathType == gopathRoot && rpath == "": - return true - case pathType == gopathSrc: - prefix := rpath + string(filepath.Separator) - for name := range test.goFiles { - if strings.HasPrefix(name, prefix) { - return true - } - } - return false - default: - return false - } -} - -type dirInfo struct { - name string -} - -func (d dirInfo) Name() string { return d.name } -func (d dirInfo) Size() int64 { return 0 } -func (d dirInfo) Mode() os.FileMode { return os.ModeDir | os.ModePerm } -func (d dirInfo) ModTime() time.Time { return time.Unix(0, 0) } -func (d dirInfo) IsDir() bool { return true } -func (d dirInfo) Sys() interface{} { return nil } - -type fileInfo struct { - name string - size int64 -} - -func (f fileInfo) Name() string { return f.name } -func (f fileInfo) Size() int64 { return f.size } -func (f fileInfo) Mode() os.FileMode { return os.ModeDir | 0666 } -func (f fileInfo) ModTime() time.Time { return time.Unix(0, 0) } -func (f fileInfo) IsDir() bool { return false } -func (f fileInfo) Sys() interface{} { return nil } - -func (test *testCase) openFile(path string) (io.ReadCloser, error) { - rpath, pathType := test.resolve(path) - switch { - case pathType == systemPath: - return os.Open(path) - case pathType == gopathSrc: - content, ok := test.goFiles[rpath] - if !ok { - return nil, &os.PathError{ - Op: "open", - Path: path, - Err: errors.New("does not exist or is not a file"), - } - } - return ioutil.NopCloser(bytes.NewReader(content)), nil - default: - return nil, &os.PathError{ - Op: "open", - Path: path, - Err: errors.New("does not exist or is not a file"), - } - } -} - // materialize creates a new GOPATH at the given directory, which may or // may not exist. func (test *testCase) materialize(gopath string) error { for name, content := range test.goFiles { - dst := filepath.Join(gopath, "src", name) + dst := filepath.Join(gopath, "src", filepath.FromSlash(name)) if err := os.MkdirAll(filepath.Dir(dst), 0777); err != nil { return fmt.Errorf("materialize GOPATH: %v", err) } @@ -675,11 +550,11 @@ func (test *testCase) materialize(gopath string) error { // // ... (Dependency files copied) func writeGoMod(gopath string) error { - importPath := "example.com" - depPath := "github.com/google/go-cloud" + const importPath = "example.com" + const depPath = "github.com/google/go-cloud" depLoc := filepath.Join(gopath, "src", filepath.FromSlash(depPath)) example := fmt.Sprintf("module %s\n\nreplace %s => %s\n", importPath, depPath, depLoc) - gomod := filepath.Join(gopath, "src", importPath, "go.mod") + gomod := filepath.Join(gopath, "src", filepath.FromSlash(importPath), "go.mod") if err := ioutil.WriteFile(gomod, []byte(example), 0666); err != nil { return fmt.Errorf("generate go.mod for %s: %v", gomod, err) } @@ -688,25 +563,3 @@ func writeGoMod(gopath string) error { } return nil } - -// runGo runs a go command in dir. -func runGo(bctx *build.Context, dir string, args ...string) error { - exe := filepath.Join(bctx.GOROOT, "bin", "go") - c := exec.Command(exe, args...) - c.Env = append(os.Environ(), "GOROOT="+bctx.GOROOT, "GOARCH="+bctx.GOARCH, "GOOS="+bctx.GOOS, "GOPATH="+bctx.GOPATH) - c.Dir = dir - if bctx.CgoEnabled { - c.Env = append(c.Env, "CGO_ENABLED=1") - } else { - c.Env = append(c.Env, "CGO_ENABLED=0") - } - // TODO(someday): Set -compiler flag if needed. - out, err := c.CombinedOutput() - if err != nil { - if len(out) > 0 { - return fmt.Errorf("%v; output:\n%s", err, out) - } - return err - } - return nil -}