From 7207be44a11c9d07bd2cf80ee0f89509f18e1a25 Mon Sep 17 00:00:00 2001 From: "M. J. Fromberger" Date: Sat, 13 Jan 2024 16:22:17 -0800 Subject: [PATCH] Test that panicking in the Tx callback properly calls Cancel. --- atomicfile.go | 6 +++--- atomicfile_test.go | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/atomicfile.go b/atomicfile.go index bac60b9..32160d2 100644 --- a/atomicfile.go +++ b/atomicfile.go @@ -39,9 +39,9 @@ func New(target string, mode os.FileMode) (*File, error) { }, nil } -// Tx calls f with a file constructed by New. If f reports an error, the file -// is automatically cancelled and Tx returns the error from f. Otherwise, Tx -// returns the error from calling Close on the file. +// Tx calls f with a file constructed by New. If f reports an error or panics, +// the file is automatically cancelled and Tx returns the error from f. +// Otherwise, Tx returns the error from calling Close on the file. func Tx(target string, mode os.FileMode, f func(*File) error) error { tmp, err := New(target, mode) if err != nil { diff --git a/atomicfile_test.go b/atomicfile_test.go index 7adffdd..8eff238 100644 --- a/atomicfile_test.go +++ b/atomicfile_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/creachadair/atomicfile" + "github.com/creachadair/mtest" ) var ( @@ -195,6 +196,28 @@ func TestTx(t *testing.T) { } checkFile(t, path, 0604, text) }) + + t.Run("Panic", func(t *testing.T) { + path := filepath.Join(t.TempDir(), "knucklebones.txt") + v := mtest.MustPanic(t, func() { + atomicfile.Tx(path, 0600, func(*atomicfile.File) error { + panic("ouchies") + }) + }) + + // Make sure we got the panic from the callback. + if s, ok := v.(string); !ok || s != "ouchies" { + t.Errorf("Unexpected panic: %v", v) + } + + // Make sure nothing was left in the output directory. + elts, err := os.ReadDir(filepath.Dir(path)) + if err != nil { + t.Fatalf("Reading output directory: %v", err) + } else if len(elts) != 0 { + t.Errorf("Unexpected output: %v", elts) + } + }) } func TestWrite(t *testing.T) {