diff --git a/arguments/parser.go b/arguments/parser.go index 557d495..c6a2f19 100644 --- a/arguments/parser.go +++ b/arguments/parser.go @@ -1,6 +1,7 @@ package arguments import ( + "bytes" "encoding/json" "errors" "flag" @@ -9,87 +10,221 @@ import ( "path/filepath" "regexp" "strings" + "text/template" "unicode" ) -func New(args []string, workingDir string, evaler Evaler, stater Stater) (*ParsedArguments, error) { - if len(args) == 0 { - return nil, errors.New("argument parsing requires at least one argument") - } +type flagsForGenerate struct { + FakeNameTemplate *string +} - fs := flag.NewFlagSet("counterfeiter", flag.ContinueOnError) - fakeNameFlag := fs.String( - "fake-name", +func (f *flagsForGenerate) RegisterFlags(fs *flag.FlagSet) { + f.FakeNameTemplate = fs.String( + "fake-name-template", "", - "The name of the fake struct", + `A template for the names of the fake structs in a generate call. Example: "The{{.TargetName}}Imposter"`, ) +} - outputPathFlag := fs.String( - "o", +type flagsForNonGenerate struct { + FakeName *string + Package *bool +} + +func (f *flagsForNonGenerate) RegisterFlags(fs *flag.FlagSet) { + f.FakeName = fs.String( + "fake-name", "", - "The file or directory to which the generated fake will be written", + "The name of the fake struct", ) - packageFlag := fs.Bool( + f.Package = fs.Bool( "p", false, "Whether or not to generate a package shim", ) - generateFlag := fs.Bool( +} + +type sharedFlags struct { + Generate *bool + OutputPath *string + Header *string + Quiet *bool + Help *bool +} + +func (f *sharedFlags) RegisterFlags(fs *flag.FlagSet) { + f.OutputPath = fs.String( + "o", + "", + "The file or directory to which the generated fake will be written", + ) + + f.Generate = fs.Bool( "generate", false, "Identify all //counterfeiter:generate directives in the current working directory and generate fakes for them", ) - headerFlag := fs.String( + + f.Header = fs.String( "header", "", "A path to a file that should be used as a header for the generated fake", ) - quietFlag := fs.Bool( + + f.Quiet = fs.Bool( "q", false, "Suppress status statements", ) - helpFlag := fs.Bool( + + f.Help = fs.Bool( "help", false, "Display this help", ) +} + +type allFlags struct { + flagsForGenerate + flagsForNonGenerate + sharedFlags +} + +func (f *allFlags) RegisterFlags(fs *flag.FlagSet) { + f.flagsForGenerate.RegisterFlags(fs) + f.flagsForNonGenerate.RegisterFlags(fs) + f.sharedFlags.RegisterFlags(fs) +} + +type standardFlags struct { + flagsForNonGenerate + sharedFlags +} + +func (f *standardFlags) RegisterFlags(fs *flag.FlagSet) { + f.flagsForNonGenerate.RegisterFlags(fs) + f.sharedFlags.RegisterFlags(fs) +} + +type GenerateArgs struct { + OutputPath string + FakeNameTemplate *template.Template + Header string + Quiet bool +} + +func ParseGenerateMode(args []string) (bool, *GenerateArgs, error) { + if len(args) == 0 { + return false, nil, errors.New("argument parsing requires at least one argument") + } + + fs := flag.NewFlagSet("counterfeiter", flag.ContinueOnError) + flags := new(allFlags) + flags.RegisterFlags(fs) + + err := fs.Parse(args[1:]) + if err != nil { + return false, nil, err + } + + if *flags.Help { + return false, nil, errors.New(usage) + } + if !*flags.Generate { + return false, nil, nil + } + + fakeNameTemplate, err := template.New("counterfeiter").Parse(*flags.FakeNameTemplate) + if err != nil { + return false, nil, fmt.Errorf("error parsing fake-name-template: %w", err) + } + + return true, &GenerateArgs{ + OutputPath: *flags.OutputPath, + FakeNameTemplate: fakeNameTemplate, + Header: *flags.Header, + Quiet: *flags.Quiet, + }, nil +} + +type FakeNameTemplateArg struct { + TargetName string +} + +func New(args []string, workingDir string, generateArgs *GenerateArgs, evaler Evaler, stater Stater) (*ParsedArguments, error) { + if len(args) == 0 { + return nil, errors.New("argument parsing requires at least one argument") + } + + fs := flag.NewFlagSet("counterfeiter", flag.ContinueOnError) + flags := new(standardFlags) + flags.RegisterFlags(fs) err := fs.Parse(args[1:]) if err != nil { return nil, err } - if *helpFlag { + + if len(fs.Args()) == 0 && !*flags.Generate { return nil, errors.New(usage) } - if len(fs.Args()) == 0 && !*generateFlag { - return nil, errors.New(usage) + + header := *flags.Header + outputPath := *flags.OutputPath + quiet := *flags.Quiet + fakeName := *flags.FakeName + + if generateArgs != nil { + header = or(header, generateArgs.Header) + quiet = quiet || generateArgs.Quiet + } - packageMode := *packageFlag result := &ParsedArguments{ PrintToStdOut: any(args, "-"), - GenerateInterfaceAndShimFromPackageDirectory: packageMode, - GenerateMode: *generateFlag, - HeaderFile: *headerFlag, - Quiet: *quietFlag, - } - if *generateFlag { - return result, nil + GenerateInterfaceAndShimFromPackageDirectory: *flags.Package, + HeaderFile: header, + Quiet: quiet, } - err = result.parseSourcePackageDir(packageMode, workingDir, evaler, stater, fs.Args()) + + err = result.parseSourcePackageDir(*flags.Package, workingDir, evaler, stater, fs.Args()) if err != nil { return nil, err } - result.parseInterfaceName(packageMode, fs.Args()) - result.parseFakeName(packageMode, *fakeNameFlag, fs.Args()) - result.parseOutputPath(packageMode, workingDir, *outputPathFlag, fs.Args()) - result.parseDestinationPackageName(packageMode, fs.Args()) - result.parsePackagePath(packageMode, fs.Args()) + result.parseInterfaceName(*flags.Package, fs.Args()) + + if generateArgs != nil { + outputPath = or(outputPath, generateArgs.OutputPath) + + if fakeName == "" && generateArgs.FakeNameTemplate != nil { + fakeNameWriter := new(bytes.Buffer) + err = generateArgs.FakeNameTemplate.Execute( + fakeNameWriter, + FakeNameTemplateArg{TargetName: fixupUnexportedNames(result.InterfaceName)}, + ) + if err != nil { + return nil, fmt.Errorf("error evaluating fake-name-template: %w", err) + } + fakeName = fakeNameWriter.String() + } + } + result.parseFakeName(*flags.Package, fakeName, fs.Args()) + result.parseOutputPath(*flags.Package, workingDir, outputPath, fs.Args()) + result.parseDestinationPackageName(*flags.Package, fs.Args()) + result.parsePackagePath(*flags.Package, fs.Args()) return result, nil } +func or(opts ...string) string { + for _, s := range opts { + if s != "" { + return s + } + } + return "" +} + func (a *ParsedArguments) PrettyPrint() { b, _ := json.Marshal(a) fmt.Println(string(b)) @@ -207,7 +342,6 @@ type ParsedArguments struct { FakeImplName string // the name of the struct implementing the given interface PrintToStdOut bool - GenerateMode bool Quiet bool HeaderFile string diff --git a/arguments/parser_test.go b/arguments/parser_test.go index b231096..46a3d55 100644 --- a/arguments/parser_test.go +++ b/arguments/parser_test.go @@ -1,3 +1,4 @@ +//go:build !windows // +build !windows package arguments_test @@ -10,9 +11,11 @@ import ( "os" "path" "path/filepath" + "testing" + "text/template" "time" - "testing" + "github.com/onsi/gomega/gbytes" "github.com/maxbrunsfeld/counterfeiter/v6/arguments" @@ -22,10 +25,97 @@ import ( ) func TestParsingArguments(t *testing.T) { - spec.Run(t, "ParsingArguments", testParsingArguments, spec.Report(report.Terminal{})) + spec.Run(t, "ParseGenerateMode", testParseGenerateMode, spec.Report(report.Terminal{})) + spec.Run(t, "ParsingArguments", testNew, spec.Report(report.Terminal{})) +} + +func testParseGenerateMode(t *testing.T, when spec.G, it spec.S) { + it.Before(func() { + RegisterTestingT(t) + log.SetOutput(ioutil.Discard) + }) + + when("-generate is given", func() { + it("returns true", func() { + generateMode, generateArgs, err := arguments.ParseGenerateMode([]string{"counterfeiter", + "-generate", + "-o", "fake", + "-fake-name-template", `The{{.TargetName}}Imposter`, + "-header", "my-header", + "-q", + }) + Expect(err).NotTo(HaveOccurred()) + Expect(generateMode).To(BeTrue()) + Expect(generateArgs).NotTo(BeNil()) + Expect(generateArgs.OutputPath).To(Equal("fake")) + Expect(generateArgs.Header).To(Equal("my-header")) + Expect(generateArgs.Quiet).To(BeTrue()) + + Expect(generateArgs.FakeNameTemplate).NotTo(BeNil()) + nameWriter := gbytes.NewBuffer() + Expect( + generateArgs.FakeNameTemplate.Execute(nameWriter, struct{ TargetName string }{"MyType"}), + ).To(Succeed()) + Expect(string(nameWriter.Contents())).To(Equal("TheMyTypeImposter")) + }) + + when("the fake-name-template is invalid", func() { + it("errors", func() { + _, _, err := arguments.ParseGenerateMode([]string{"counterfeiter", + "-generate", + "-fake-name-template", `{{panic "boom"}}`, + }) + Expect(err).To(MatchError(ContainSubstring("fake-name-template"))) + }) + }) + }) + + when("-generate is not given", func() { + it("returns false and nil", func() { + generateMode, generateArgs, err := arguments.ParseGenerateMode([]string{"counterfeiter", + "-o", "fake", + "-fake-name", "Bob", + }) + Expect(err).NotTo(HaveOccurred()) + Expect(generateMode).To(BeFalse()) + Expect(generateArgs).To(BeNil()) + }) + }) + + when("no args are given", func() { + it("returns false and nil", func() { + generateMode, generateArgs, err := arguments.ParseGenerateMode([]string{"counterfeiter"}) + Expect(err).NotTo(HaveOccurred()) + Expect(generateMode).To(BeFalse()) + Expect(generateArgs).To(BeNil()) + }) + }) + + when("unknown flags are given", func() { + it("returns an error", func() { + generateMode, generateArgs, err := arguments.ParseGenerateMode([]string{"counterfeiter", "-generate", "-no-such-flag"}) + Expect(err).To(HaveOccurred()) + Expect(generateMode).To(BeFalse()) + Expect(generateArgs).To(BeNil()) + }) + }) + + when("-help is given", func() { + it("returns an error", func() { + generateMode, generateArgs, err := arguments.ParseGenerateMode([]string{"counterfeiter", "-generate", "-help"}) + Expect(err).To(HaveOccurred()) + Expect(generateMode).To(BeFalse()) + Expect(generateArgs).To(BeNil()) + + generateMode, generateArgs, err = arguments.ParseGenerateMode([]string{"counterfeiter", "-help"}) + Expect(err).To(HaveOccurred()) + Expect(generateMode).To(BeFalse()) + Expect(generateArgs).To(BeNil()) + }) + }) } -func testParsingArguments(t *testing.T, when spec.G, it spec.S) { +func testNew(t *testing.T, when spec.G, it spec.S) { var ( err error parsedArgs *arguments.ParsedArguments @@ -35,10 +125,6 @@ func testParsingArguments(t *testing.T, when spec.G, it spec.S) { stater arguments.Stater ) - justBefore := func() { - parsedArgs, err = arguments.New(args, workingDir, evaler, stater) - } - it.Before(func() { RegisterTestingT(t) log.SetOutput(ioutil.Discard) @@ -52,307 +138,464 @@ func testParsingArguments(t *testing.T, when spec.G, it spec.S) { } }) - when("when the -p flag is provided", func() { - it.Before(func() { - args = []string{"counterfeiter", "-p", "os"} - justBefore() - }) - - it("doesn't parse extraneous arguments", func() { - Expect(err).To(Succeed()) - Expect(parsedArgs.GenerateInterfaceAndShimFromPackageDirectory).To(BeTrue()) - Expect(parsedArgs.InterfaceName).To(Equal("")) - Expect(parsedArgs.FakeImplName).To(Equal("Os")) - }) + when("not in generate mode", func() { + justBefore := func() { + parsedArgs, err = arguments.New(args, workingDir, nil, evaler, stater) + } - when("given a stdlib package", func() { - it("sets arguments as expected", func() { - Expect(parsedArgs.SourcePackageDir).To(Equal("os")) - Expect(parsedArgs.OutputPath).To(Equal(path.Join(workingDir, "osshim", "os.go"))) - Expect(parsedArgs.DestinationPackageName).To(Equal("osshim")) + when("when the -p flag is provided", func() { + it.Before(func() { + args = []string{"counterfeiter", "-p", "os"} + justBefore() }) - }) - }) - - when("when a single argument is provided", func() { - it.Before(func() { - args = []string{"counterfeiter", "someonesinterfaces.AnInterface"} - justBefore() - }) - - it("sets PrintToStdOut to false", func() { - Expect(parsedArgs.PrintToStdOut).To(BeFalse()) - }) - - it("provides a name for the fake implementing the interface", func() { - Expect(parsedArgs.FakeImplName).To(Equal("FakeAnInterface")) - }) - it("provides a path for the interface source", func() { - Expect(parsedArgs.PackagePath).To(Equal("someonesinterfaces")) - }) + it("doesn't parse extraneous arguments", func() { + Expect(err).To(Succeed()) + Expect(parsedArgs.GenerateInterfaceAndShimFromPackageDirectory).To(BeTrue()) + Expect(parsedArgs.InterfaceName).To(Equal("")) + Expect(parsedArgs.FakeImplName).To(Equal("Os")) + }) - it("treats the last segment as the interface to counterfeit", func() { - Expect(parsedArgs.InterfaceName).To(Equal("AnInterface")) + when("given a stdlib package", func() { + it("sets arguments as expected", func() { + Expect(parsedArgs.SourcePackageDir).To(Equal("os")) + Expect(parsedArgs.OutputPath).To(Equal(path.Join(workingDir, "osshim", "os.go"))) + Expect(parsedArgs.DestinationPackageName).To(Equal("osshim")) + }) + }) }) - it("snake cases the filename for the output directory", func() { - Expect(parsedArgs.OutputPath).To(Equal( - filepath.Join( - workingDir, - "workspacefakes", - "fake_an_interface.go", - ), - )) - }) - }) - - when("when a single argument is provided with the output directory", func() { - it.Before(func() { - args = []string{"counterfeiter", "-o", "/tmp/foo", "io.Writer"} - justBefore() - }) + when("when a single argument is provided", func() { + it.Before(func() { + args = []string{"counterfeiter", "someonesinterfaces.AnInterface"} + justBefore() + }) - it("indicates to not print to stdout", func() { - Expect(parsedArgs.PrintToStdOut).To(BeFalse()) - }) + it("sets PrintToStdOut to false", func() { + Expect(parsedArgs.PrintToStdOut).To(BeFalse()) + }) - it("provides a name for the fake implementing the interface", func() { - Expect(parsedArgs.FakeImplName).To(Equal("FakeWriter")) - }) + it("provides a name for the fake implementing the interface", func() { + Expect(parsedArgs.FakeImplName).To(Equal("FakeAnInterface")) + }) - it("provides a path for the interface source", func() { - Expect(parsedArgs.PackagePath).To(Equal("io")) - }) + it("provides a path for the interface source", func() { + Expect(parsedArgs.PackagePath).To(Equal("someonesinterfaces")) + }) - it("treats the last segment as the interface to counterfeit", func() { - Expect(parsedArgs.InterfaceName).To(Equal("Writer")) - }) + it("treats the last segment as the interface to counterfeit", func() { + Expect(parsedArgs.InterfaceName).To(Equal("AnInterface")) + }) - it("copies the provided output path into the result", func() { - Expect(parsedArgs.OutputPath).To(Equal("/tmp/foo/fake_writer.go")) + it("snake cases the filename for the output directory", func() { + Expect(parsedArgs.OutputPath).To(Equal( + filepath.Join( + workingDir, + "workspacefakes", + "fake_an_interface.go", + ), + )) + }) }) - }) - when("when a single argument is provided with the output file", func() { - it.Before(func() { - args = []string{"counterfeiter", "-o", "/tmp/foo/fake_foo.go", "io.Writer"} - justBefore() - }) + when("when a single argument is provided with the output directory", func() { + it.Before(func() { + args = []string{"counterfeiter", "-o", "/tmp/foo", "io.Writer"} + justBefore() + }) - it("indicates to not print to stdout", func() { - Expect(parsedArgs.PrintToStdOut).To(BeFalse()) - }) + it("indicates to not print to stdout", func() { + Expect(parsedArgs.PrintToStdOut).To(BeFalse()) + }) - it("provides a name for the fake implementing the interface", func() { - Expect(parsedArgs.FakeImplName).To(Equal("FakeWriter")) - }) + it("provides a name for the fake implementing the interface", func() { + Expect(parsedArgs.FakeImplName).To(Equal("FakeWriter")) + }) - it("provides a path for the interface source", func() { - Expect(parsedArgs.PackagePath).To(Equal("io")) - }) + it("provides a path for the interface source", func() { + Expect(parsedArgs.PackagePath).To(Equal("io")) + }) - it("treats the last segment as the interface to counterfeit", func() { - Expect(parsedArgs.InterfaceName).To(Equal("Writer")) - }) + it("treats the last segment as the interface to counterfeit", func() { + Expect(parsedArgs.InterfaceName).To(Equal("Writer")) + }) - it("copies the provided output path into the result", func() { - Expect(parsedArgs.OutputPath).To(Equal("/tmp/foo/fake_foo.go")) + it("copies the provided output path into the result", func() { + Expect(parsedArgs.OutputPath).To(Equal("/tmp/foo/fake_writer.go")) + }) }) - }) - when("when two arguments are provided", func() { - it.Before(func() { - args = []string{"counterfeiter", "my/my5package", "MySpecialInterface"} - justBefore() - }) + when("when a single argument is provided with the output file", func() { + it.Before(func() { + args = []string{"counterfeiter", "-o", "/tmp/foo/fake_foo.go", "io.Writer"} + justBefore() + }) - it("indicates to not print to stdout", func() { - Expect(parsedArgs.PrintToStdOut).To(BeFalse()) - }) + it("indicates to not print to stdout", func() { + Expect(parsedArgs.PrintToStdOut).To(BeFalse()) + }) - it("provides a name for the fake implementing the interface", func() { - Expect(parsedArgs.FakeImplName).To(Equal("FakeMySpecialInterface")) - }) + it("provides a name for the fake implementing the interface", func() { + Expect(parsedArgs.FakeImplName).To(Equal("FakeWriter")) + }) - it("treats the second argument as the interface to counterfeit", func() { - Expect(parsedArgs.InterfaceName).To(Equal("MySpecialInterface")) - }) + it("provides a path for the interface source", func() { + Expect(parsedArgs.PackagePath).To(Equal("io")) + }) - it("snake cases the filename for the output directory", func() { - Expect(parsedArgs.OutputPath).To(Equal( - filepath.Join( - parsedArgs.SourcePackageDir, - "my5packagefakes", - "fake_my_special_interface.go", - ), - )) - }) + it("treats the last segment as the interface to counterfeit", func() { + Expect(parsedArgs.InterfaceName).To(Equal("Writer")) + }) - it("specifies the destination package name", func() { - Expect(parsedArgs.DestinationPackageName).To(Equal("my5packagefakes")) + it("copies the provided output path into the result", func() { + Expect(parsedArgs.OutputPath).To(Equal("/tmp/foo/fake_foo.go")) + }) }) - when("when the interface is unexported", func() { + when("when two arguments are provided", func() { it.Before(func() { - args = []string{"counterfeiter", "my/mypackage", "mySpecialInterface"} + args = []string{"counterfeiter", "my/my5package", "MySpecialInterface"} justBefore() }) - it("fixes up the fake name to be TitleCase", func() { + it("indicates to not print to stdout", func() { + Expect(parsedArgs.PrintToStdOut).To(BeFalse()) + }) + + it("provides a name for the fake implementing the interface", func() { Expect(parsedArgs.FakeImplName).To(Equal("FakeMySpecialInterface")) }) + it("treats the second argument as the interface to counterfeit", func() { + Expect(parsedArgs.InterfaceName).To(Equal("MySpecialInterface")) + }) + it("snake cases the filename for the output directory", func() { Expect(parsedArgs.OutputPath).To(Equal( filepath.Join( parsedArgs.SourcePackageDir, - "mypackagefakes", + "my5packagefakes", "fake_my_special_interface.go", ), )) }) - }) - when("the source directory", func() { - it("should be an absolute path", func() { - Expect(filepath.IsAbs(parsedArgs.SourcePackageDir)).To(BeTrue()) + it("specifies the destination package name", func() { + Expect(parsedArgs.DestinationPackageName).To(Equal("my5packagefakes")) }) - when("when the first arg is a path to a file", func() { + when("when the interface is unexported", func() { it.Before(func() { - stater = func(filename string) (os.FileInfo, error) { - return fakeFileInfo(filename, false), nil - } + args = []string{"counterfeiter", "my/mypackage", "mySpecialInterface"} justBefore() }) - it("should be the directory containing the file", func() { - Expect(parsedArgs.SourcePackageDir).ToNot(ContainSubstring("something.go")) + it("fixes up the fake name to be TitleCase", func() { + Expect(parsedArgs.FakeImplName).To(Equal("FakeMySpecialInterface")) + }) + + it("snake cases the filename for the output directory", func() { + Expect(parsedArgs.OutputPath).To(Equal( + filepath.Join( + parsedArgs.SourcePackageDir, + "mypackagefakes", + "fake_my_special_interface.go", + ), + )) + }) + }) + + when("the source directory", func() { + it("should be an absolute path", func() { + Expect(filepath.IsAbs(parsedArgs.SourcePackageDir)).To(BeTrue()) + }) + + when("when the first arg is a path to a file", func() { + it.Before(func() { + stater = func(filename string) (os.FileInfo, error) { + return fakeFileInfo(filename, false), nil + } + justBefore() + }) + + it("should be the directory containing the file", func() { + Expect(parsedArgs.SourcePackageDir).ToNot(ContainSubstring("something.go")) + }) + }) + + when("when evaluating symlinks fails", func() { + it.Before(func() { + evaler = func(input string) (string, error) { + return "", errors.New("aww shucks") + } + justBefore() + }) + + it("should return an error with a useful message", func() { + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal(fmt.Sprintf("No such file/directory/package [%s]: aww shucks", path.Join(workingDir, "my/my5package")))) + }) + }) + + when("when the file stat cannot be read", func() { + it.Before(func() { + stater = func(_ string) (os.FileInfo, error) { + return fakeFileInfo("", false), errors.New("submarine-shoutout") + } + justBefore() + }) + + it("should return an error with a useful message", func() { + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal(fmt.Sprintf("No such file/directory/package [%s]: submarine-shoutout", path.Join(workingDir, "my/my5package")))) + }) }) }) + }) + + when("when the output dir contains characters inappropriate for a package name", func() { + it.Before(func() { + args = []string{"counterfeiter", "@my-special-package[]{}", "MySpecialInterface"} + justBefore() + }) + + it("should choose a valid package name", func() { + Expect(parsedArgs.DestinationPackageName).To(Equal("myspecialpackagefakes")) + }) + }) - when("when evaluating symlinks fails", func() { + when("when three arguments are provided", func() { + when("and the third one is '-'", func() { it.Before(func() { - evaler = func(input string) (string, error) { - return "", errors.New("aww shucks") - } + args = []string{"counterfeiter", "my/mypackage", "MySpecialInterface", "-"} justBefore() }) - it("should return an error with a useful message", func() { - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(Equal(fmt.Sprintf("No such file/directory/package [%s]: aww shucks", path.Join(workingDir, "my/my5package")))) + it("treats the second argument as the interface to counterfeit", func() { + Expect(parsedArgs.InterfaceName).To(Equal("MySpecialInterface")) + }) + + it("provides a name for the fake implementing the interface", func() { + Expect(parsedArgs.FakeImplName).To(Equal("FakeMySpecialInterface")) + }) + + it("indicates that the fake should be printed to stdout", func() { + Expect(parsedArgs.PrintToStdOut).To(BeTrue()) + }) + + it("snake cases the filename for the output directory", func() { + Expect(parsedArgs.OutputPath).To(Equal( + filepath.Join( + parsedArgs.SourcePackageDir, + "mypackagefakes", + "fake_my_special_interface.go", + ), + )) + }) + + when("the source directory", func() { + it("should be an absolute path", func() { + Expect(filepath.IsAbs(parsedArgs.SourcePackageDir)).To(BeTrue()) + }) + + when("when the first arg is a path to a file", func() { + it.Before(func() { + stater = func(filename string) (os.FileInfo, error) { + return fakeFileInfo(filename, false), nil + } + }) + + it("should be the directory containing the file", func() { + Expect(parsedArgs.SourcePackageDir).ToNot(ContainSubstring("something.go")) + }) + }) }) }) - when("when the file stat cannot be read", func() { + when("and the third one is some random input", func() { it.Before(func() { - stater = func(_ string) (os.FileInfo, error) { - return fakeFileInfo("", false), errors.New("submarine-shoutout") - } + args = []string{"counterfeiter", "my/mypackage", "MySpecialInterface", "WHOOPS"} justBefore() }) - it("should return an error with a useful message", func() { - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(Equal(fmt.Sprintf("No such file/directory/package [%s]: submarine-shoutout", path.Join(workingDir, "my/my5package")))) + it("indicates to not print to stdout", func() { + Expect(parsedArgs.PrintToStdOut).To(BeFalse()) }) }) }) - }) - when("when the output dir contains characters inappropriate for a package name", func() { - it.Before(func() { - args = []string{"counterfeiter", "@my-special-package[]{}", "MySpecialInterface"} - justBefore() - }) + when("when the output dir contains underscores in package name", func() { + it.Before(func() { + args = []string{"counterfeiter", "fake_command_runner", "MySpecialInterface"} + justBefore() + }) - it("should choose a valid package name", func() { - Expect(parsedArgs.DestinationPackageName).To(Equal("myspecialpackagefakes")) + it("should ensure underscores are in the package name", func() { + Expect(parsedArgs.DestinationPackageName).To(Equal("fake_command_runnerfakes")) + }) }) - }) - when("when three arguments are provided", func() { - when("and the third one is '-'", func() { + when("when '-header' is used", func() { it.Before(func() { - args = []string{"counterfeiter", "my/mypackage", "MySpecialInterface", "-"} + args = []string{"counterfeiter", "-header", "some/header/file", "some.interface"} justBefore() }) - it("treats the second argument as the interface to counterfeit", func() { - Expect(parsedArgs.InterfaceName).To(Equal("MySpecialInterface")) + it("sets the HeaderFile attribute on the parsedArgs struct", func() { + Expect(parsedArgs.HeaderFile).To(Equal("some/header/file")) + Expect(err).NotTo(HaveOccurred()) }) + }) + }) - it("provides a name for the fake implementing the interface", func() { - Expect(parsedArgs.FakeImplName).To(Equal("FakeMySpecialInterface")) + when("in generate mode", func() { + var generateArgs arguments.GenerateArgs + + it.Before(func() { + generateArgs = arguments.GenerateArgs{} + }) + + justBefore := func() { + parsedArgs, err = arguments.New(args, workingDir, &generateArgs, evaler, stater) + } + + when("generate was called with -o", func() { + it.Before(func() { + generateArgs.OutputPath = "generate-output" }) - it("indicates that the fake should be printed to stdout", func() { - Expect(parsedArgs.PrintToStdOut).To(BeTrue()) + when("the invocation specified -o also", func() { + it.Before(func() { + args = []string{"counterfeiter", "-o", "output", "someonesinterfaces.AnInterface"} + justBefore() + }) + + it("chooses the invocation's output path", func() { + Expect(parsedArgs.OutputPath).To(Equal( + filepath.Join( + workingDir, + "output", + "fake_an_interface.go", + ), + )) + }) }) - it("snake cases the filename for the output directory", func() { - Expect(parsedArgs.OutputPath).To(Equal( - filepath.Join( - parsedArgs.SourcePackageDir, - "mypackagefakes", - "fake_my_special_interface.go", - ), - )) + when("the invocation did not specify -o", func() { + it.Before(func() { + args = []string{"counterfeiter", "someonesinterfaces.AnInterface"} + justBefore() + }) + + it("chooses the generate call's output path", func() { + Expect(parsedArgs.OutputPath).To(Equal( + filepath.Join( + workingDir, + "generate-output", + "fake_an_interface.go", + ), + )) + }) }) + }) - when("the source directory", func() { - it("should be an absolute path", func() { - Expect(filepath.IsAbs(parsedArgs.SourcePackageDir)).To(BeTrue()) + when("generate was called with -q", func() { + it.Before(func() { + generateArgs.Quiet = true + }) + + when("the invocation specified -q also", func() { + it.Before(func() { + args = []string{"counterfeiter", "-q", "someonesinterfaces.AnInterface"} + justBefore() }) - when("when the first arg is a path to a file", func() { - it.Before(func() { - stater = func(filename string) (os.FileInfo, error) { - return fakeFileInfo(filename, false), nil - } - }) + it("is quiet", func() { + Expect(parsedArgs.Quiet).To(BeTrue()) + }) + }) - it("should be the directory containing the file", func() { - Expect(parsedArgs.SourcePackageDir).ToNot(ContainSubstring("something.go")) - }) + when("the invocation did not specify -q", func() { + it.Before(func() { + args = []string{"counterfeiter", "someonesinterfaces.AnInterface"} + justBefore() + }) + + it("is quiet", func() { + Expect(parsedArgs.Quiet).To(BeTrue()) }) }) }) - when("and the third one is some random input", func() { + when("generate was called with -header", func() { it.Before(func() { - args = []string{"counterfeiter", "my/mypackage", "MySpecialInterface", "WHOOPS"} - justBefore() + generateArgs.Header = "generate-header" }) - it("indicates to not print to stdout", func() { - Expect(parsedArgs.PrintToStdOut).To(BeFalse()) + when("the invocation specified -header also", func() { + it.Before(func() { + args = []string{"counterfeiter", "-header", "header", "someonesinterfaces.AnInterface"} + justBefore() + }) + + it("chooses the invocation's header", func() { + Expect(parsedArgs.HeaderFile).To(Equal("header")) + }) }) - }) - }) - when("when the output dir contains underscores in package name", func() { - it.Before(func() { - args = []string{"counterfeiter", "fake_command_runner", "MySpecialInterface"} - justBefore() - }) + when("the invocation did not specify -header", func() { + it.Before(func() { + args = []string{"counterfeiter", "someonesinterfaces.AnInterface"} + justBefore() + }) - it("should ensure underscores are in the package name", func() { - Expect(parsedArgs.DestinationPackageName).To(Equal("fake_command_runnerfakes")) + it("chooses the generate call's header", func() { + Expect(parsedArgs.HeaderFile).To(Equal("generate-header")) + }) + }) }) - }) - when("when '-header' is used", func() { - it.Before(func() { - args = []string{"counterfeiter", "-header", "some/header/file", "some.interface"} - justBefore() - }) + when("generate was called with -fake-name-template", func() { + it.Before(func() { + generateArgs.FakeNameTemplate, err = template.New("test").Parse("The{{.TargetName}}Imposter") + Expect(err).NotTo(HaveOccurred()) + }) - it("sets the HeaderFile attriburte on the parsedArgs struct", func() { - Expect(parsedArgs.HeaderFile).To(Equal("some/header/file")) - Expect(err).NotTo(HaveOccurred()) + when("the invocation specified -fake-name", func() { + it.Before(func() { + args = []string{"counterfeiter", "-fake-name", "FakestFake", "someonesinterfaces.AnInterface"} + justBefore() + }) + + it("chooses the invocation's fake name", func() { + Expect(parsedArgs.FakeImplName).To(Equal("FakestFake")) + }) + }) + + when("the invocation did not specify -fake-name", func() { + it.Before(func() { + args = []string{"counterfeiter", "someonesinterfaces.AnInterface"} + justBefore() + }) + + it("uses the fake-name-template to generate a fake name", func() { + Expect(parsedArgs.FakeImplName).To(Equal("TheAnInterfaceImposter")) + }) + }) + + when("the template is invalid", func() { + it.Before(func() { + generateArgs.FakeNameTemplate, err = template.New("test").Parse("{{.NoSuchField}}") + Expect(err).NotTo(HaveOccurred()) + + args = []string{"counterfeiter", "someonesinterfaces.AnInterface"} + justBefore() + }) + + it("errors", func() { + Expect(err).To(MatchError(ContainSubstring("fake-name-template"))) + }) + }) }) }) } diff --git a/fixtures/generate_defaults/fakes/not_the_real_song.go b/fixtures/generate_defaults/fakes/not_the_real_song.go new file mode 100644 index 0000000..1b3f9e0 --- /dev/null +++ b/fixtures/generate_defaults/fakes/not_the_real_song.go @@ -0,0 +1,102 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package fakes + +import ( + "sync" + + "github.com/maxbrunsfeld/counterfeiter/v6/fixtures/generate_defaults" +) + +type NotTheRealSong struct { + SongStub func() string + songMutex sync.RWMutex + songArgsForCall []struct { + } + songReturns struct { + result1 string + } + songReturnsOnCall map[int]struct { + result1 string + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *NotTheRealSong) Song() string { + fake.songMutex.Lock() + ret, specificReturn := fake.songReturnsOnCall[len(fake.songArgsForCall)] + fake.songArgsForCall = append(fake.songArgsForCall, struct { + }{}) + stub := fake.SongStub + fakeReturns := fake.songReturns + fake.recordInvocation("Song", []interface{}{}) + fake.songMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *NotTheRealSong) SongCallCount() int { + fake.songMutex.RLock() + defer fake.songMutex.RUnlock() + return len(fake.songArgsForCall) +} + +func (fake *NotTheRealSong) SongCalls(stub func() string) { + fake.songMutex.Lock() + defer fake.songMutex.Unlock() + fake.SongStub = stub +} + +func (fake *NotTheRealSong) SongReturns(result1 string) { + fake.songMutex.Lock() + defer fake.songMutex.Unlock() + fake.SongStub = nil + fake.songReturns = struct { + result1 string + }{result1} +} + +func (fake *NotTheRealSong) SongReturnsOnCall(i int, result1 string) { + fake.songMutex.Lock() + defer fake.songMutex.Unlock() + fake.SongStub = nil + if fake.songReturnsOnCall == nil { + fake.songReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.songReturnsOnCall[i] = struct { + result1 string + }{result1} +} + +func (fake *NotTheRealSong) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.songMutex.RLock() + defer fake.songMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *NotTheRealSong) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ generate_defaults.Song = new(NotTheRealSong) diff --git a/fixtures/generate_defaults/fakes/the_sing_imposter.go b/fixtures/generate_defaults/fakes/the_sing_imposter.go new file mode 100644 index 0000000..e3f455b --- /dev/null +++ b/fixtures/generate_defaults/fakes/the_sing_imposter.go @@ -0,0 +1,102 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package fakes + +import ( + "sync" + + "github.com/maxbrunsfeld/counterfeiter/v6/fixtures/generate_defaults" +) + +type TheSingImposter struct { + SingStub func() string + singMutex sync.RWMutex + singArgsForCall []struct { + } + singReturns struct { + result1 string + } + singReturnsOnCall map[int]struct { + result1 string + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *TheSingImposter) Sing() string { + fake.singMutex.Lock() + ret, specificReturn := fake.singReturnsOnCall[len(fake.singArgsForCall)] + fake.singArgsForCall = append(fake.singArgsForCall, struct { + }{}) + stub := fake.SingStub + fakeReturns := fake.singReturns + fake.recordInvocation("Sing", []interface{}{}) + fake.singMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *TheSingImposter) SingCallCount() int { + fake.singMutex.RLock() + defer fake.singMutex.RUnlock() + return len(fake.singArgsForCall) +} + +func (fake *TheSingImposter) SingCalls(stub func() string) { + fake.singMutex.Lock() + defer fake.singMutex.Unlock() + fake.SingStub = stub +} + +func (fake *TheSingImposter) SingReturns(result1 string) { + fake.singMutex.Lock() + defer fake.singMutex.Unlock() + fake.SingStub = nil + fake.singReturns = struct { + result1 string + }{result1} +} + +func (fake *TheSingImposter) SingReturnsOnCall(i int, result1 string) { + fake.singMutex.Lock() + defer fake.singMutex.Unlock() + fake.SingStub = nil + if fake.singReturnsOnCall == nil { + fake.singReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.singReturnsOnCall[i] = struct { + result1 string + }{result1} +} + +func (fake *TheSingImposter) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.singMutex.RLock() + defer fake.singMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *TheSingImposter) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ generate_defaults.Sing = new(TheSingImposter) diff --git a/fixtures/generate_defaults/other-fakes/ponger.go b/fixtures/generate_defaults/other-fakes/ponger.go new file mode 100644 index 0000000..aad794d --- /dev/null +++ b/fixtures/generate_defaults/other-fakes/ponger.go @@ -0,0 +1,102 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package otherfakes + +import ( + "sync" + + "github.com/maxbrunsfeld/counterfeiter/v6/fixtures/generate_defaults" +) + +type Ponger struct { + PongStub func() string + pongMutex sync.RWMutex + pongArgsForCall []struct { + } + pongReturns struct { + result1 string + } + pongReturnsOnCall map[int]struct { + result1 string + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *Ponger) Pong() string { + fake.pongMutex.Lock() + ret, specificReturn := fake.pongReturnsOnCall[len(fake.pongArgsForCall)] + fake.pongArgsForCall = append(fake.pongArgsForCall, struct { + }{}) + stub := fake.PongStub + fakeReturns := fake.pongReturns + fake.recordInvocation("Pong", []interface{}{}) + fake.pongMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *Ponger) PongCallCount() int { + fake.pongMutex.RLock() + defer fake.pongMutex.RUnlock() + return len(fake.pongArgsForCall) +} + +func (fake *Ponger) PongCalls(stub func() string) { + fake.pongMutex.Lock() + defer fake.pongMutex.Unlock() + fake.PongStub = stub +} + +func (fake *Ponger) PongReturns(result1 string) { + fake.pongMutex.Lock() + defer fake.pongMutex.Unlock() + fake.PongStub = nil + fake.pongReturns = struct { + result1 string + }{result1} +} + +func (fake *Ponger) PongReturnsOnCall(i int, result1 string) { + fake.pongMutex.Lock() + defer fake.pongMutex.Unlock() + fake.PongStub = nil + if fake.pongReturnsOnCall == nil { + fake.pongReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.pongReturnsOnCall[i] = struct { + result1 string + }{result1} +} + +func (fake *Ponger) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.pongMutex.RLock() + defer fake.pongMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *Ponger) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ generate_defaults.Pong = new(Ponger) diff --git a/fixtures/generate_defaults/other-fakes/the_sang_imposter.go b/fixtures/generate_defaults/other-fakes/the_sang_imposter.go new file mode 100644 index 0000000..660f7de --- /dev/null +++ b/fixtures/generate_defaults/other-fakes/the_sang_imposter.go @@ -0,0 +1,102 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package otherfakes + +import ( + "sync" + + "github.com/maxbrunsfeld/counterfeiter/v6/fixtures/generate_defaults" +) + +type TheSangImposter struct { + SangStub func() string + sangMutex sync.RWMutex + sangArgsForCall []struct { + } + sangReturns struct { + result1 string + } + sangReturnsOnCall map[int]struct { + result1 string + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *TheSangImposter) Sang() string { + fake.sangMutex.Lock() + ret, specificReturn := fake.sangReturnsOnCall[len(fake.sangArgsForCall)] + fake.sangArgsForCall = append(fake.sangArgsForCall, struct { + }{}) + stub := fake.SangStub + fakeReturns := fake.sangReturns + fake.recordInvocation("Sang", []interface{}{}) + fake.sangMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *TheSangImposter) SangCallCount() int { + fake.sangMutex.RLock() + defer fake.sangMutex.RUnlock() + return len(fake.sangArgsForCall) +} + +func (fake *TheSangImposter) SangCalls(stub func() string) { + fake.sangMutex.Lock() + defer fake.sangMutex.Unlock() + fake.SangStub = stub +} + +func (fake *TheSangImposter) SangReturns(result1 string) { + fake.sangMutex.Lock() + defer fake.sangMutex.Unlock() + fake.SangStub = nil + fake.sangReturns = struct { + result1 string + }{result1} +} + +func (fake *TheSangImposter) SangReturnsOnCall(i int, result1 string) { + fake.sangMutex.Lock() + defer fake.sangMutex.Unlock() + fake.SangStub = nil + if fake.sangReturnsOnCall == nil { + fake.sangReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.sangReturnsOnCall[i] = struct { + result1 string + }{result1} +} + +func (fake *TheSangImposter) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.sangMutex.RLock() + defer fake.sangMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *TheSangImposter) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ generate_defaults.Sang = new(TheSangImposter) diff --git a/fixtures/generate_defaults/output_and_fake_name.go b/fixtures/generate_defaults/output_and_fake_name.go new file mode 100644 index 0000000..e50f156 --- /dev/null +++ b/fixtures/generate_defaults/output_and_fake_name.go @@ -0,0 +1,23 @@ +package generate_defaults + +//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate -o fakes -fake-name-template "The{{.TargetName}}Imposter" + +//counterfeiter:generate . Sing +type Sing interface { + Sing() string +} + +//counterfeiter:generate -o other-fakes . Sang +type Sang interface { + Sang() string +} + +//counterfeiter:generate -fake-name NotTheRealSong . Song +type Song interface { + Song() string +} + +//counterfeiter:generate -o other-fakes -fake-name Ponger . Pong +type Pong interface { + Pong() string +} diff --git a/main.go b/main.go index 0efb3f8..bdfe2d9 100644 --- a/main.go +++ b/main.go @@ -61,35 +61,26 @@ func run() error { cache = &generator.Cache{} headerReader = &generator.CachedFileReader{} } - var invocations []command.Invocation - var args *arguments.ParsedArguments - args, _ = arguments.New(os.Args, cwd, filepath.EvalSymlinks, os.Stat) - generateMode := false - if args != nil { - generateMode = args.GenerateMode + + generateMode, generateArgs, err := arguments.ParseGenerateMode(os.Args) + if err != nil { + return err } + if !generateMode && shouldPrintGenerateWarning() { fmt.Printf("\nWARNING: Invoking counterfeiter multiple times from \"go generate\" is slow.\nConsider using counterfeiter:generate directives to speed things up.\nSee https://github.com/maxbrunsfeld/counterfeiter#step-2b---add-counterfeitergenerate-directives for more information.\nSet the \"COUNTERFEITER_NO_GENERATE_WARNING\" environment variable to suppress this message.\n\n") } - invocations, err = command.Detect(cwd, os.Args, generateMode) + invocations, err := command.Detect(cwd, os.Args, generateMode) if err != nil { return err } for i := range invocations { - a, err := arguments.New(invocations[i].Args, cwd, filepath.EvalSymlinks, os.Stat) + a, err := arguments.New(invocations[i].Args, cwd, generateArgs, filepath.EvalSymlinks, os.Stat) if err != nil { return err } - // If the '//counterfeiter:generate ...' line does not have a '-header' - // flag, we use the one from the "global" - // '//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate -header /some/header.txt' - // line (which defaults to none). By doing so, we can configure the header - // once per package, which is probably the most common case for adding - // licence headers (i.e. all the fakes will have the same licence headers). - a.HeaderFile = or(a.HeaderFile, args.HeaderFile) - err = generate(cwd, a, cache, headerReader) if err != nil { return err @@ -98,15 +89,6 @@ func run() error { return nil } -func or(opts ...string) string { - for _, s := range opts { - if s != "" { - return s - } - } - return "" -} - func isDebug() bool { return os.Getenv("COUNTERFEITER_DEBUG") != "" }