diff --git a/printer/print.go b/printer/print.go index b90d03d9..44c31786 100644 --- a/printer/print.go +++ b/printer/print.go @@ -1,7 +1,7 @@ package printer import ( - "bufio" + "bytes" "fmt" "github.com/tinylib/msgp/gen" "github.com/tinylib/msgp/parse" @@ -9,7 +9,6 @@ import ( "golang.org/x/tools/imports" "io" "io/ioutil" - "os" "strings" ) @@ -21,98 +20,82 @@ func infof(s string, v ...interface{}) { // of elements to the given file name and canonical // package path. func PrintFile(file string, f *parse.FileSet, mode gen.Method) error { - err := generate(file, f, mode) + out, tests, err := generate(f, mode) if err != nil { return err } - infof(">>> Generated \"%s\"\n", file) - err = format(file) + + // we'll run goimports on the main file + // in another goroutine, and run it here + // for the test file. empirically, this + // takes about the same amount of time as + // doing them in serial when GOMAXPROCS=1, + // and faster otherwise. + res := goformat(file, out.Bytes()) + if tests != nil { + testfile := strings.TrimSuffix(file, ".go") + "_test.go" + err = format(testfile, tests.Bytes()) + if err != nil { + return err + } + infof(">>> Wrote and formatted \"%s\"\n", testfile) + } + err = <-res if err != nil { return err } - infof(">>> Formatted \"%s\"\n", file) infof(">>> Done.\n") return nil } -func format(file string) error { - data, err := ioutil.ReadFile(file) +func format(file string, data []byte) error { + out, err := imports.Process(file, data, nil) if err != nil { return err } - data, err = imports.Process(file, data, nil) - if err != nil { - return err - } - return ioutil.WriteFile(file, data, 0600) + return ioutil.WriteFile(file, out, 0600) } -func generate(file string, f *parse.FileSet, mode gen.Method) error { - if mode&^gen.Test == 0 { - return nil - } +func goformat(file string, data []byte) <-chan error { + out := make(chan error, 1) + go func(file string, data []byte, end chan error) { + end <- format(file, data) + infof(">>> Wrote and formatted \"%s\"\n", file) + }(file, data, out) + return out +} - outfile, err := os.Create(file) - if err != nil { - return err - } - defer outfile.Close() - outwr := bufio.NewWriter(outfile) - defer outwr.Flush() +func generate(f *parse.FileSet, mode gen.Method) (*bytes.Buffer, *bytes.Buffer, error) { + outbuf := bytes.NewBuffer(make([]byte, 0, 4096)) + writePkgHeader(outbuf, f.Package) + writeImportHeader(outbuf, "github.com/tinylib/msgp/msgp") - err = writePkgHeader(outwr, f.Package) - if err != nil { - return err - } - err = writeImportHeader(outwr, "github.com/tinylib/msgp/msgp") - if err != nil { - return err - } - - var testwr *bufio.Writer + var testbuf *bytes.Buffer + var testwr io.Writer if mode&gen.Test == gen.Test { - tfname := strings.TrimSuffix(file, ".go") + "_test.go" - testfile, err := os.Create(strings.TrimSuffix(file, ".go") + "_test.go") - if err != nil { - return err - } - infof(">>> Tests in \"%s\"\n", tfname) - defer testfile.Close() - testwr = bufio.NewWriter(testfile) - defer testwr.Flush() - err = writePkgHeader(testwr, f.Package) - if err != nil { - return err - } + testbuf = bytes.NewBuffer(make([]byte, 0, 4096)) + writePkgHeader(testbuf, f.Package) if mode&(gen.Encode|gen.Decode) != 0 { - err = writeImportHeader(testwr, "bytes", "github.com/tinylib/msgp/msgp", "testing") + writeImportHeader(testbuf, "bytes", "github.com/tinylib/msgp/msgp", "testing") } else { - err = writeImportHeader(testwr, "github.com/tinylib/msgp/msgp", "testing") + writeImportHeader(testbuf, "github.com/tinylib/msgp/msgp", "testing") } + testwr = testbuf } - return f.PrintTo(gen.NewPrinter(mode, outwr, testwr)) + return outbuf, testbuf, f.PrintTo(gen.NewPrinter(mode, outbuf, testwr)) } -func writePkgHeader(w io.Writer, name string) error { - _, err := fmt.Fprintln(w, "package", name, "\n") - if err != nil { - return err - } - _, err = io.WriteString(w, "// NOTE: THIS FILE WAS PRODUCED BY THE\n// MSGP CODE GENERATION TOOL (github.com/tinylib/msgp)\n// DO NOT EDIT\n\n") - return err +func writePkgHeader(b *bytes.Buffer, name string) { + b.WriteString("package ") + b.WriteString(name) + b.WriteByte('\n') + b.WriteString("// NOTE: THIS FILE WAS PRODUCED BY THE\n// MSGP CODE GENERATION TOOL (github.com/tinylib/msgp)\n// DO NOT EDIT\n\n") } -func writeImportHeader(w io.Writer, imports ...string) error { - _, err := io.WriteString(w, "import (\n") - if err != nil { - return err - } +func writeImportHeader(b *bytes.Buffer, imports ...string) { + b.WriteString("import (\n") for _, im := range imports { - _, err = io.WriteString(w, fmt.Sprintf("\t%q\n", im)) - if err != nil { - return err - } + fmt.Fprintf(b, "\t%q\n", im) } - _, err = io.WriteString(w, ")\n\n") - return err + b.WriteString(")\n\n") }