Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Unwatch() to file provider. #306

Merged
merged 1 commit into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ func main() {
k.Print()
})

// To stop a file watcher, call:
// f.Unwatch()

// Block forever (and manually make a change to mock/mock.json) to
// reload the config.
log.Println("waiting forever. Try making a change to mock/mock.json to live reload")
Expand Down
1 change: 1 addition & 0 deletions go.work.sum
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4=
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
Expand Down
43 changes: 36 additions & 7 deletions providers/file/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"os"
"path/filepath"
"sync/atomic"
"time"

"github.com/fsnotify/fsnotify"
Expand All @@ -16,6 +17,11 @@ import (
// File implements a File provider.
type File struct {
path string
w *fsnotify.Watcher

// Using Go 1.18 atomic functions for backwards compatibility.
isWatching uint32
isUnwatched uint32
}

// Provider returns a file provider.
Expand All @@ -36,6 +42,11 @@ func (f *File) Read() (map[string]interface{}, error) {
// Watch watches the file and triggers a callback when it changes. It is a
// blocking function that internally spawns a goroutine to watch for changes.
func (f *File) Watch(cb func(event interface{}, err error)) error {
// If a watcher already exists, return an error.
if atomic.LoadUint32(&f.isWatching) == 1 {
return errors.New("file is already being watched")
}

// Resolve symlinks and save the original path so that changes to symlinks
// can be detected.
realPath, err := filepath.EvalSymlinks(f.path)
Expand All @@ -48,11 +59,13 @@ func (f *File) Watch(cb func(event interface{}, err error)) error {
// the whole parent directory to pick up all events such as symlink changes.
fDir, _ := filepath.Split(f.path)

w, err := fsnotify.NewWatcher()
f.w, err = fsnotify.NewWatcher()
if err != nil {
return err
}

atomic.StoreUint32(&f.isWatching, 1)

var (
lastEvent string
lastEventTime time.Time
Expand All @@ -62,9 +75,13 @@ func (f *File) Watch(cb func(event interface{}, err error)) error {
loop:
for {
select {
case event, ok := <-w.Events:
case event, ok := <-f.w.Events:
if !ok {
cb(nil, errors.New("fsnotify watch channel closed"))
// Only throw an error if it was not an explicit unwatch.
if atomic.LoadUint32(&f.isUnwatched) == 0 {
cb(nil, errors.New("fsnotify watch channel closed"))
}

break loop
}

Expand Down Expand Up @@ -108,9 +125,13 @@ func (f *File) Watch(cb func(event interface{}, err error)) error {
cb(nil, nil)

// There's an error.
case err, ok := <-w.Errors:
case err, ok := <-f.w.Errors:
if !ok {
cb(nil, errors.New("fsnotify err channel closed"))
// Only throw an error if it was not an explicit unwatch.
if atomic.LoadUint32(&f.isUnwatched) == 0 {
cb(nil, errors.New("fsnotify err channel closed"))
}

break loop
}

Expand All @@ -120,9 +141,17 @@ func (f *File) Watch(cb func(event interface{}, err error)) error {
}
}

w.Close()
atomic.StoreUint32(&f.isWatching, 0)
atomic.StoreUint32(&f.isUnwatched, 0)
f.w.Close()
}()

// Watch the directory for changes.
return w.Add(fDir)
return f.w.Add(fDir)
}

// Unwatch stops watching the files and closes fsnotify watcher.
func (f *File) Unwatch() error {
atomic.StoreUint32(&f.isUnwatched, 1)
return f.w.Close()
}
49 changes: 49 additions & 0 deletions tests/koanf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,55 @@ func TestWatchFileSymlink(t *testing.T) {
}, "symlink watch reload didn't change config")
}

func TestUnwatchFile(t *testing.T) {
var (
assert = assert.New(t)
k = koanf.New(delim)
)

// Create a tmp config file.
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "koanf_mock")
require.NoError(t, os.WriteFile(tmpFile, []byte(`{"parent": {"name": "name1"}}`), 0600))

// Load the new config file.
f := file.Provider(tmpFile)
k.Load(f, json.Parser())

// Watch.
reloaded := false
f.Watch(func(event interface{}, err error) {
reloaded = true
assert.NoError(err)
})

// Change the file and check whether the watch triggered.
time.Sleep(100 * time.Millisecond)
os.WriteFile(tmpFile, []byte(`{"parent": {"name": "name2"}}`), 0600)
time.Sleep(100 * time.Millisecond)
assert.True(reloaded, "watched file didn't reload")

// Unwatch the file and verify that the watch didn't triger.
assert.NoError(f.Unwatch())
reloaded = false
time.Sleep(100 * time.Millisecond)
os.WriteFile(tmpFile, []byte(`{"parent": {"name": "name3"}}`), 0600)
time.Sleep(100 * time.Millisecond)
assert.False(reloaded, "unwatched file reloaded")

// Re-watch and check again.
reloaded = false
f.Watch(func(event interface{}, err error) {
reloaded = true
assert.NoError(err)
})
os.WriteFile(tmpFile, []byte(`{"parent": {"name": "name4"}}`), 0600)
time.Sleep(100 * time.Millisecond)
assert.True(reloaded, "watched file didn't reload")

f.Unwatch()
}

func TestLoadMerge(t *testing.T) {
var (
assert = assert.New(t)
Expand Down
Loading