diff --git a/README.md b/README.md index 856e358e..87c438c9 100644 --- a/README.md +++ b/README.md @@ -145,6 +145,9 @@ func main() { k.Print() }) + // To stop a file watcher, call: + // f.Unwatch() + // Block forever (and manually make a change to mock/mock.json) to // reload the config. log.Println("waiting forever. Try making a change to mock/mock.json to live reload") diff --git a/go.work.sum b/go.work.sum index c2837736..eab7bcf8 100644 --- a/go.work.sum +++ b/go.work.sum @@ -145,6 +145,7 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4= golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= diff --git a/providers/file/file.go b/providers/file/file.go index 4ccc62af..b3990125 100644 --- a/providers/file/file.go +++ b/providers/file/file.go @@ -8,6 +8,7 @@ import ( "fmt" "os" "path/filepath" + "sync/atomic" "time" "github.com/fsnotify/fsnotify" @@ -16,6 +17,12 @@ import ( // File implements a File provider. type File struct { path string + w *fsnotify.Watcher + + // Using Int32 for Go 1.18 backwards compatibility. + // Bool was added in 1.19 + isWatching atomic.Int32 + isUnwatched atomic.Int32 } // Provider returns a file provider. @@ -36,6 +43,11 @@ func (f *File) Read() (map[string]interface{}, error) { // Watch watches the file and triggers a callback when it changes. It is a // blocking function that internally spawns a goroutine to watch for changes. func (f *File) Watch(cb func(event interface{}, err error)) error { + // If a watcher already exists, return an error. + if f.isWatching.Load() == 1 { + return errors.New("file is already being watched") + } + // Resolve symlinks and save the original path so that changes to symlinks // can be detected. realPath, err := filepath.EvalSymlinks(f.path) @@ -48,11 +60,13 @@ func (f *File) Watch(cb func(event interface{}, err error)) error { // the whole parent directory to pick up all events such as symlink changes. fDir, _ := filepath.Split(f.path) - w, err := fsnotify.NewWatcher() + f.w, err = fsnotify.NewWatcher() if err != nil { return err } + f.isWatching.Store(1) + var ( lastEvent string lastEventTime time.Time @@ -62,9 +76,13 @@ func (f *File) Watch(cb func(event interface{}, err error)) error { loop: for { select { - case event, ok := <-w.Events: + case event, ok := <-f.w.Events: if !ok { - cb(nil, errors.New("fsnotify watch channel closed")) + // Only throw an error if it was not an explicit unwatch. + if f.isUnwatched.Load() == 0 { + cb(nil, errors.New("fsnotify watch channel closed")) + } + break loop } @@ -108,9 +126,13 @@ func (f *File) Watch(cb func(event interface{}, err error)) error { cb(nil, nil) // There's an error. - case err, ok := <-w.Errors: + case err, ok := <-f.w.Errors: if !ok { - cb(nil, errors.New("fsnotify err channel closed")) + // Only throw an error if it was not an explicit unwatch. + if f.isUnwatched.Load() == 0 { + cb(nil, errors.New("fsnotify err channel closed")) + } + break loop } @@ -120,9 +142,17 @@ func (f *File) Watch(cb func(event interface{}, err error)) error { } } - w.Close() + f.isWatching.Store(0) + f.isUnwatched.Store(0) + f.w.Close() }() // Watch the directory for changes. - return w.Add(fDir) + return f.w.Add(fDir) +} + +// Unwatch stops watching the files and closes fsnotify watcher. +func (f *File) Unwatch() error { + f.isUnwatched.Store(1) + return f.w.Close() } diff --git a/tests/koanf_test.go b/tests/koanf_test.go index edf5eeb8..16ab9537 100644 --- a/tests/koanf_test.go +++ b/tests/koanf_test.go @@ -533,6 +533,55 @@ func TestWatchFileSymlink(t *testing.T) { }, "symlink watch reload didn't change config") } +func TestUnwatchFile(t *testing.T) { + var ( + assert = assert.New(t) + k = koanf.New(delim) + ) + + // Create a tmp config file. + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "koanf_mock") + require.NoError(t, os.WriteFile(tmpFile, []byte(`{"parent": {"name": "name1"}}`), 0600)) + + // Load the new config file. + f := file.Provider(tmpFile) + k.Load(f, json.Parser()) + + // Watch. + reloaded := false + f.Watch(func(event interface{}, err error) { + reloaded = true + assert.NoError(err) + }) + + // Change the file and check whether the watch triggered. + time.Sleep(100 * time.Millisecond) + os.WriteFile(tmpFile, []byte(`{"parent": {"name": "name2"}}`), 0600) + time.Sleep(100 * time.Millisecond) + assert.True(reloaded, "watched file didn't reload") + + // Unwatch the file and verify that the watch didn't triger. + assert.NoError(f.Unwatch()) + reloaded = false + time.Sleep(100 * time.Millisecond) + os.WriteFile(tmpFile, []byte(`{"parent": {"name": "name3"}}`), 0600) + time.Sleep(100 * time.Millisecond) + assert.False(reloaded, "unwatched file reloaded") + + // Re-watch and check again. + reloaded = false + f.Watch(func(event interface{}, err error) { + reloaded = true + assert.NoError(err) + }) + os.WriteFile(tmpFile, []byte(`{"parent": {"name": "name4"}}`), 0600) + time.Sleep(100 * time.Millisecond) + assert.True(reloaded, "watched file didn't reload") + + f.Unwatch() +} + func TestLoadMerge(t *testing.T) { var ( assert = assert.New(t)