From b2e71562b59bdbb9e9ed1311772973ca0a30847d Mon Sep 17 00:00:00 2001 From: Nick Travers Date: Thu, 6 Oct 2022 10:59:19 -0700 Subject: [PATCH] tool: capture stdout / stderr Currently, the tools write directly to `os.{Stdout,Stderr}`, which complicates testing in the case where the command is run from same process running the tests (i.e. from a `testing.T` func). Adapt the tools to make use of the return values from `(*cobra.Command).{OutOrStdout,OutOrStderr}`. In testing scenarios, the tools can use `SetOut` and `SetErr` to pass in a writer that can intercept anything written to stdout / stderr. In productions scenarios, the tools will emit to the appropriate channel. Touches cockroachdb/cockroach#89095. --- tool/data_test.go | 8 ++------ tool/db.go | 24 ++++++++++++++++-------- tool/find.go | 23 ++++++++++++----------- tool/manifest.go | 14 ++++++++------ tool/sstable.go | 20 +++++++++++++------- tool/util.go | 7 ++----- tool/wal.go | 1 + 7 files changed, 54 insertions(+), 43 deletions(-) diff --git a/tool/data_test.go b/tool/data_test.go index d6c306fb05..ed289c0943 100644 --- a/tool/data_test.go +++ b/tool/data_test.go @@ -69,15 +69,10 @@ func runTests(t *testing.T, path string) { } var buf bytes.Buffer - stdout = &buf - stderr = &buf - var secs int64 timeNow = func() time.Time { secs++; return time.Unix(secs, 0) } defer func() { - stdout = os.Stdout - stderr = os.Stderr timeNow = time.Now }() @@ -121,7 +116,8 @@ func runTests(t *testing.T, path string) { c := &cobra.Command{} c.AddCommand(tool.Commands...) c.SetArgs(args) - c.SetOutput(&buf) + c.SetOut(&buf) + c.SetErr(&buf) if err := c.Execute(); err != nil { return err.Error() } diff --git a/tool/db.go b/tool/db.go index 4a21239d12..73b6a3ae61 100644 --- a/tool/db.go +++ b/tool/db.go @@ -278,19 +278,20 @@ func (d *dbT) openDB(dir string, openOptions ...openOption) (*pebble.DB, error) return pebble.Open(dir, &opts) } -func (d *dbT) closeDB(db *pebble.DB) { +func (d *dbT) closeDB(stdout io.Writer, db *pebble.DB) { if err := db.Close(); err != nil { fmt.Fprintf(stdout, "%s\n", err) } } func (d *dbT) runCheck(cmd *cobra.Command, args []string) { + stdout := cmd.OutOrStdout() db, err := d.openDB(args[0]) if err != nil { fmt.Fprintf(stdout, "%s\n", err) return } - defer d.closeDB(db) + defer d.closeDB(stdout, db) var stats pebble.CheckLevelsStats if err := db.CheckLevels(&stats); err != nil { @@ -310,12 +311,13 @@ func (n nonReadOnly) apply(opts *pebble.Options) { } func (d *dbT) runCheckpoint(cmd *cobra.Command, args []string) { + stdout := cmd.OutOrStdout() db, err := d.openDB(args[0], nonReadOnly{}) if err != nil { fmt.Fprintf(stdout, "%s\n", err) return } - defer d.closeDB(db) + defer d.closeDB(stdout, db) destDir := args[1] if err := db.Checkpoint(destDir); err != nil { @@ -324,12 +326,13 @@ func (d *dbT) runCheckpoint(cmd *cobra.Command, args []string) { } func (d *dbT) runGet(cmd *cobra.Command, args []string) { + stdout := cmd.OutOrStdout() db, err := d.openDB(args[0]) if err != nil { fmt.Fprintf(stdout, "%s\n", err) return } - defer d.closeDB(db) + defer d.closeDB(stdout, db) var k key if err := k.Set(args[1]); err != nil { fmt.Fprintf(stdout, "%s\n", err) @@ -352,23 +355,25 @@ func (d *dbT) runGet(cmd *cobra.Command, args []string) { } func (d *dbT) runLSM(cmd *cobra.Command, args []string) { + stdout := cmd.OutOrStdout() db, err := d.openDB(args[0]) if err != nil { fmt.Fprintf(stdout, "%s\n", err) return } - defer d.closeDB(db) + defer d.closeDB(stdout, db) fmt.Fprintf(stdout, "%s", db.Metrics()) } func (d *dbT) runScan(cmd *cobra.Command, args []string) { + stdout := cmd.OutOrStdout() db, err := d.openDB(args[0]) if err != nil { fmt.Fprintf(stdout, "%s\n", err) return } - defer d.closeDB(db) + defer d.closeDB(stdout, db) // Update the internal formatter if this comparator has one specified. if d.opts.Comparer != nil { @@ -417,12 +422,13 @@ func (d *dbT) runScan(cmd *cobra.Command, args []string) { } func (d *dbT) runSpace(cmd *cobra.Command, args []string) { + stdout, stderr := cmd.OutOrStdout(), cmd.OutOrStderr() db, err := d.openDB(args[0]) if err != nil { fmt.Fprintf(stderr, "%s\n", err) return } - defer d.closeDB(db) + defer d.closeDB(stdout, db) bytes, err := db.EstimateDiskUsage(d.start, d.end) if err != nil { @@ -433,6 +439,7 @@ func (d *dbT) runSpace(cmd *cobra.Command, args []string) { } func (d *dbT) runProperties(cmd *cobra.Command, args []string) { + stdout, stderr := cmd.OutOrStdout(), cmd.OutOrStderr() dirname := args[0] err := func() error { desc, err := pebble.Peek(dirname, d.opts.FS) @@ -557,12 +564,13 @@ func (d *dbT) runProperties(cmd *cobra.Command, args []string) { } func (d *dbT) runSet(cmd *cobra.Command, args []string) { + stdout := cmd.OutOrStdout() db, err := d.openDB(args[0], nonReadOnly{}) if err != nil { fmt.Fprintf(stdout, "%s\n", err) return } - defer d.closeDB(db) + defer d.closeDB(stdout, db) var k, v key if err := k.Set(args[1]); err != nil { fmt.Fprintf(stdout, "%s\n", err) diff --git a/tool/find.go b/tool/find.go index c2c2419bc5..2563edd234 100644 --- a/tool/find.go +++ b/tool/find.go @@ -98,17 +98,18 @@ provenance of the sstables (flushed, ingested, compacted). } func (f *findT) run(cmd *cobra.Command, args []string) { + stdout, stderr := cmd.OutOrStdout(), cmd.OutOrStderr() var key key if err := key.Set(args[1]); err != nil { fmt.Fprintf(stdout, "%s\n", err) return } - if err := f.findFiles(args[0]); err != nil { + if err := f.findFiles(stdout, stderr, args[0]); err != nil { fmt.Fprintf(stdout, "%s\n", err) return } - f.readManifests() + f.readManifests(stdout) f.opts.Comparer = f.comparers[f.comparerName] if f.opts.Comparer == nil { @@ -118,7 +119,7 @@ func (f *findT) run(cmd *cobra.Command, args []string) { f.fmtKey.setForComparer(f.opts.Comparer.Name, f.comparers) f.fmtValue.setForComparer(f.opts.Comparer.Name, f.comparers) - refs := f.search(key) + refs := f.search(stdout, key) var lastFileNum base.FileNum for i := range refs { r := &refs[i] @@ -140,7 +141,7 @@ func (f *findT) run(cmd *cobra.Command, args []string) { } // Find all of the manifests, logs, and tables in the specified directory. -func (f *findT) findFiles(dir string) error { +func (f *findT) findFiles(stdout, stderr io.Writer, dir string) error { f.files = make(map[base.FileNum]string) f.editRefs = make(map[base.FileNum][]int) f.logs = nil @@ -152,7 +153,7 @@ func (f *findT) findFiles(dir string) error { return err } - walk(f.opts.FS, dir, func(path string) { + walk(stderr, f.opts.FS, dir, func(path string) { ft, fileNum, ok := base.ParseFilename(f.opts.FS, path) if !ok { return @@ -191,7 +192,7 @@ func (f *findT) findFiles(dir string) error { // Read the manifests and populate the editRefs map which is used to determine // the provenance and metadata of tables. -func (f *findT) readManifests() { +func (f *findT) readManifests(stdout io.Writer) { for _, fileNum := range f.manifests { func() { path := f.files[fileNum] @@ -255,9 +256,9 @@ func (f *findT) readManifests() { } // Search the logs and sstables for references to the specified key. -func (f *findT) search(key []byte) []findRef { - refs := f.searchLogs(key, nil) - refs = f.searchTables(key, refs) +func (f *findT) search(stdout io.Writer, key []byte) []findRef { + refs := f.searchLogs(stdout, key, nil) + refs = f.searchTables(stdout, key, refs) // For a given file (log or table) the references are already in the correct // order. We simply want to order the references by fileNum using a stable @@ -280,7 +281,7 @@ func (f *findT) search(key []byte) []findRef { } // Search the logs for references to the specified key. -func (f *findT) searchLogs(searchKey []byte, refs []findRef) []findRef { +func (f *findT) searchLogs(stdout io.Writer, searchKey []byte, refs []findRef) []findRef { cmp := f.opts.Comparer.Compare for _, fileNum := range f.logs { _ = func() (err error) { @@ -374,7 +375,7 @@ func (f *findT) searchLogs(searchKey []byte, refs []findRef) []findRef { } // Search the tables for references to the specified key. -func (f *findT) searchTables(searchKey []byte, refs []findRef) []findRef { +func (f *findT) searchTables(stdout io.Writer, searchKey []byte, refs []findRef) []findRef { cache := pebble.NewCache(128 << 20 /* 128 MB */) defer cache.Unref() diff --git a/tool/manifest.go b/tool/manifest.go index bc74e2334a..fd00862df6 100644 --- a/tool/manifest.go +++ b/tool/manifest.go @@ -96,7 +96,7 @@ Check the contents of the MANIFEST files. return m } -func (m *manifestT) printLevels(v *manifest.Version) { +func (m *manifestT) printLevels(stdout io.Writer, v *manifest.Version) { for level := range v.Levels { if level == 0 && len(v.L0SublevelFiles) > 0 && !v.Levels[level].Empty() { for sublevel := len(v.L0SublevelFiles) - 1; sublevel >= 0; sublevel-- { @@ -122,6 +122,7 @@ func (m *manifestT) printLevels(v *manifest.Version) { } func (m *manifestT) runDump(cmd *cobra.Command, args []string) { + stdout, stderr := cmd.OutOrStdout(), cmd.OutOrStderr() for _, arg := range args { func() { f, err := m.opts.FS.Open(arg) @@ -226,7 +227,7 @@ func (m *manifestT) runDump(cmd *cobra.Command, args []string) { fmt.Fprintf(stdout, "%s\n", err) return } - m.printLevels(v) + m.printLevels(stdout, v) } }() } @@ -234,14 +235,14 @@ func (m *manifestT) runDump(cmd *cobra.Command, args []string) { func (m *manifestT) runSummarize(cmd *cobra.Command, args []string) { for _, arg := range args { - err := m.runSummarizeOne(arg) + err := m.runSummarizeOne(cmd.OutOrStdout(), arg) if err != nil { - fmt.Fprintf(stderr, "%s\n", err) + fmt.Fprintf(cmd.OutOrStderr(), "%s\n", err) } } } -func (m *manifestT) runSummarizeOne(arg string) error { +func (m *manifestT) runSummarizeOne(stdout io.Writer, arg string) error { f, err := m.opts.FS.Open(arg) if err != nil { return err @@ -428,6 +429,7 @@ func (m *manifestT) runSummarizeOne(arg string) error { } func (m *manifestT) runCheck(cmd *cobra.Command, args []string) { + stdout, stderr := cmd.OutOrStdout(), cmd.OutOrStderr() ok := true for _, arg := range args { func() { @@ -498,7 +500,7 @@ func (m *manifestT) runCheck(cmd *cobra.Command, args []string) { fmt.Fprintf(stdout, "%s: offset: %d err: %s\n", arg, offset, err) fmt.Fprintf(stdout, "Version state before failed Apply\n") - m.printLevels(v) + m.printLevels(stdout, v) fmt.Fprintf(stdout, "Version edit that failed\n") for df := range ve.DeletedFiles { fmt.Fprintf(stdout, " deleted: L%d %s\n", df.Level, df.FileNum) diff --git a/tool/sstable.go b/tool/sstable.go index f15383c187..3e064be7c9 100644 --- a/tool/sstable.go +++ b/tool/sstable.go @@ -7,6 +7,7 @@ package tool import ( "bytes" "fmt" + "io" "os" "path/filepath" "sort" @@ -151,7 +152,8 @@ func (s *sstableT) newReader(f vfs.File) (*sstable.Reader, error) { } func (s *sstableT) runCheck(cmd *cobra.Command, args []string) { - s.foreachSstable(args, func(arg string) { + stdout, stderr := cmd.OutOrStdout(), cmd.OutOrStderr() + s.foreachSstable(stderr, args, func(arg string) { f, err := s.opts.FS.Open(arg) if err != nil { fmt.Fprintf(stderr, "%s\n", err) @@ -227,7 +229,8 @@ func (s *sstableT) runCheck(cmd *cobra.Command, args []string) { } func (s *sstableT) runLayout(cmd *cobra.Command, args []string) { - s.foreachSstable(args, func(arg string) { + stdout, stderr := cmd.OutOrStdout(), cmd.OutOrStderr() + s.foreachSstable(stderr, args, func(arg string) { f, err := s.opts.FS.Open(arg) if err != nil { fmt.Fprintf(stderr, "%s\n", err) @@ -263,7 +266,8 @@ func (s *sstableT) runLayout(cmd *cobra.Command, args []string) { } func (s *sstableT) runProperties(cmd *cobra.Command, args []string) { - s.foreachSstable(args, func(arg string) { + stdout, stderr := cmd.OutOrStdout(), cmd.OutOrStderr() + s.foreachSstable(stderr, args, func(arg string) { f, err := s.opts.FS.Open(arg) if err != nil { fmt.Fprintf(stderr, "%s\n", err) @@ -355,7 +359,8 @@ func (s *sstableT) runProperties(cmd *cobra.Command, args []string) { } func (s *sstableT) runScan(cmd *cobra.Command, args []string) { - s.foreachSstable(args, func(arg string) { + stdout, stderr := cmd.OutOrStdout(), cmd.OutOrStderr() + s.foreachSstable(stderr, args, func(arg string) { f, err := s.opts.FS.Open(arg) if err != nil { fmt.Fprintf(stderr, "%s\n", err) @@ -512,7 +517,8 @@ func (s *sstableT) runScan(cmd *cobra.Command, args []string) { } func (s *sstableT) runSpace(cmd *cobra.Command, args []string) { - s.foreachSstable(args, func(arg string) { + stdout, stderr := cmd.OutOrStdout(), cmd.OutOrStderr() + s.foreachSstable(stderr, args, func(arg string) { f, err := s.opts.FS.Open(arg) if err != nil { fmt.Fprintf(stderr, "%s\n", err) @@ -534,7 +540,7 @@ func (s *sstableT) runSpace(cmd *cobra.Command, args []string) { }) } -func (s *sstableT) foreachSstable(args []string, fn func(arg string)) { +func (s *sstableT) foreachSstable(stderr io.Writer, args []string, fn func(arg string)) { // Loop over args, invoking fn for each file. Each directory is recursively // listed and fn is invoked on any file with an .sst or .ldb suffix. for _, arg := range args { @@ -543,7 +549,7 @@ func (s *sstableT) foreachSstable(args []string, fn func(arg string)) { fn(arg) continue } - walk(s.opts.FS, arg, func(path string) { + walk(stderr, s.opts.FS, arg, func(path string) { switch filepath.Ext(path) { case ".sst", ".ldb": fn(path) diff --git a/tool/util.go b/tool/util.go index fbafed0717..e315424a25 100644 --- a/tool/util.go +++ b/tool/util.go @@ -8,7 +8,6 @@ import ( "encoding/hex" "fmt" "io" - "os" "sort" "strings" "time" @@ -20,8 +19,6 @@ import ( "github.com/cockroachdb/pebble/vfs" ) -var stdout = io.Writer(os.Stdout) -var stderr = io.Writer(os.Stderr) var timeNow = time.Now type key []byte @@ -305,7 +302,7 @@ func formatSpan(w io.Writer, fmtKey keyFormatter, fmtValue valueFormatter, s *ke } } -func walk(fs vfs.FS, dir string, fn func(path string)) { +func walk(stderr io.Writer, fs vfs.FS, dir string, fn func(path string)) { paths, err := fs.List(dir) if err != nil { fmt.Fprintf(stderr, "%s: %v\n", dir, err) @@ -320,7 +317,7 @@ func walk(fs vfs.FS, dir string, fn func(path string)) { continue } if info.IsDir() { - walk(fs, path, fn) + walk(stderr, fs, path, fn) } else { fn(path) } diff --git a/tool/wal.go b/tool/wal.go index d701174bba..c2e44c07ab 100644 --- a/tool/wal.go +++ b/tool/wal.go @@ -65,6 +65,7 @@ Print the contents of the WAL files. } func (w *walT) runDump(cmd *cobra.Command, args []string) { + stdout, stderr := cmd.OutOrStdout(), cmd.OutOrStderr() w.fmtKey.setForComparer(w.defaultComparer, w.comparers) w.fmtValue.setForComparer(w.defaultComparer, w.comparers)