Skip to content

Commit

Permalink
reuse filewatcher
Browse files Browse the repository at this point in the history
  • Loading branch information
shawnh2 committed Nov 28, 2024
1 parent 17e932c commit 3496a33
Show file tree
Hide file tree
Showing 11 changed files with 233 additions and 450 deletions.
19 changes: 13 additions & 6 deletions internal/filewatcher/filewatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package filewatcher
import (
"errors"
"fmt"
"os"
"path/filepath"
"sync"

Expand Down Expand Up @@ -90,7 +91,6 @@ func (fw *fileWatcher) Add(path string) error {
return err
}

// Stop watching a path
func (fw *fileWatcher) Remove(path string) error {
fw.mu.Lock()
defer fw.mu.Unlock()
Expand Down Expand Up @@ -142,9 +142,7 @@ func (fw *fileWatcher) getWorker(path string) (*workerState, string, string, err
return nil, "", "", errors.New("using a closed watcher")
}

cleanedPath := filepath.Clean(path)
parentPath, _ := filepath.Split(cleanedPath)

cleanedPath, parentPath := getPath(path)
ws, workerExists := fw.workers[parentPath]
if !workerExists {
wk, err := newWorker(parentPath, fw.funcs)
Expand All @@ -167,8 +165,7 @@ func (fw *fileWatcher) findWorker(path string) (*workerState, string, error) {
return nil, "", errors.New("using a closed watcher")
}

cleanedPath := filepath.Clean(path)
parentPath, _ := filepath.Split(cleanedPath)
cleanedPath, parentPath := getPath(path)

ws, workerExists := fw.workers[parentPath]
if !workerExists {
Expand All @@ -177,3 +174,13 @@ func (fw *fileWatcher) findWorker(path string) (*workerState, string, error) {

return ws, cleanedPath, nil
}

func getPath(path string) (cleanedPath, parentPath string) {
cleanedPath = filepath.Clean(path)
parentPath, _ = filepath.Split(cleanedPath)
if f, err := os.Lstat(cleanedPath); err == nil && f.IsDir() {
parentPath = cleanedPath
}

return
}
65 changes: 50 additions & 15 deletions internal/filewatcher/filewatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"runtime"
"sync"
"testing"
"time"

"github.com/fsnotify/fsnotify"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -173,6 +174,44 @@ func TestWatchFile(t *testing.T) {
})
}

func TestWatchDir(t *testing.T) {
// Given a file being watched
watchFile := newWatchFile(t)
_, err := os.Stat(watchFile)
require.NoError(t, err)

w := NewWatcher()
defer func() {
_ = w.Close()
}()
d := path.Dir(watchFile)
require.NoError(t, w.Add(d))

timeout := time.After(5 * time.Second)

wg := sync.WaitGroup{}
var timeoutErr error
wg.Add(1)
go func() {
select {
case <-w.Events(d):

case <-w.Events(watchFile):

case <-timeout:
timeoutErr = errors.New("timeout")
}
wg.Done()
}()

// Overwriting the file and waiting its event to be received.
err = os.WriteFile(watchFile, []byte("foo: baz\n"), 0o600)
require.NoError(t, err)
wg.Wait()

require.NoErrorf(t, timeoutErr, "timeout waiting for event")
}

func TestWatcherLifecycle(t *testing.T) {
watchFile1, watchFile2 := newTwoWatchFile(t)

Expand Down Expand Up @@ -295,27 +334,23 @@ func TestBadAddWatcher(t *testing.T) {

func TestDuplicateAdd(t *testing.T) {
w := NewWatcher()

name := newWatchFile(t)
defer func() {
_ = w.Close()
_ = os.Remove(name)
}()

if err := w.Add(name); err != nil {
t.Errorf("Expecting nil, got %v", err)
}

if err := w.Add(name); err == nil {
t.Errorf("Expecting error, got nil")
}

_ = w.Close()
require.NoError(t, w.Add(name))
require.Error(t, w.Add(name))
}

func TestBogusRemove(t *testing.T) {
w := NewWatcher()

name := newWatchFile(t)
if err := w.Remove(name); err == nil {
t.Errorf("Expecting error, got nil")
}
defer func() {
_ = w.Close()
_ = os.Remove(name)
}()

_ = w.Close()
require.Error(t, w.Remove(name))
}
43 changes: 25 additions & 18 deletions internal/filewatcher/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
type worker struct {
mu sync.RWMutex

// watcher is an fsnotify watcher that watches the parent
// watcher is a fsnotify watcher that watches the parent
// dir of watchedFiles.
dirWatcher *fsnotify.Watcher

Expand Down Expand Up @@ -96,10 +96,9 @@ func (wk *worker) loop() {
continue
}

sum := getHashSum(path)
if !bytes.Equal(sum, ft.hash) {
sum, isDir := getHashSum(path)
if isDir || !bytes.Equal(sum, ft.hash) {
ft.hash = sum

select {
case ft.events <- event:
// nothing to do
Expand Down Expand Up @@ -141,7 +140,7 @@ func (wk *worker) loop() {
}
}

// used only by the worker goroutine
// drainRetiringTrackers used only by the worker goroutine
func (wk *worker) drainRetiringTrackers() {
// cleanup any trackers that were in the process
// of being retired, but didn't get processed due
Expand All @@ -156,7 +155,7 @@ func (wk *worker) drainRetiringTrackers() {
}
}

// make a local copy of the set of trackers to avoid contention with callers
// getTrackers make a local copy of the set of trackers to avoid contention with callers
// used only by the worker goroutine
func (wk *worker) getTrackers() map[string]*fileTracker {
wk.mu.RLock()
Expand Down Expand Up @@ -184,36 +183,34 @@ func (wk *worker) terminate() {

func (wk *worker) addPath(path string) error {
wk.mu.Lock()
defer wk.mu.Unlock()

ft := wk.watchedFiles[path]
if ft != nil {
wk.mu.Unlock()
return fmt.Errorf("path %s is already being watched", path)
}

h, _ := getHashSum(path)
ft = &fileTracker{
events: make(chan fsnotify.Event),
errors: make(chan error),
hash: getHashSum(path),
hash: h,
}

wk.watchedFiles[path] = ft
wk.mu.Unlock()

return nil
}

func (wk *worker) removePath(path string) error {
wk.mu.Lock()
defer wk.mu.Unlock()

ft := wk.watchedFiles[path]
if ft == nil {
wk.mu.Unlock()
return fmt.Errorf("path %s not found", path)
}

delete(wk.watchedFiles, path)
wk.mu.Unlock()

wk.retireTrackerCh <- ft
return nil
Expand Down Expand Up @@ -241,16 +238,26 @@ func (wk *worker) errorChannel(path string) chan error {
return nil
}

// gets the hash of the given file, or nil if there's a problem
func getHashSum(file string) []byte {
// getHashSum return the hash of the given file, or nil if there's a problem, or it's a directory.
func getHashSum(file string) ([]byte, bool) {
f, err := os.Open(file)
if err != nil {
return nil
return nil, false
}
defer f.Close()
r := bufio.NewReader(f)
defer func() {
_ = f.Close()
}()

fi, err := f.Stat()
if err != nil {
return nil, false
}
if fi.IsDir() {
return nil, true
}

r := bufio.NewReader(f)
h := sha256.New()
_, _ = io.Copy(h, r)
return h.Sum(nil)
return h.Sum(nil), false
}
Loading

0 comments on commit 3496a33

Please sign in to comment.