diff --git a/go.mod b/go.mod index 36fe98d7ec4..642a13fa5a8 100644 --- a/go.mod +++ b/go.mod @@ -103,6 +103,7 @@ require ( github.com/kr/text v0.2.0 github.com/mitchellh/mapstructure v1.5.0 github.com/nsf/jsondiff v0.0.0-20210926074059-1e845ec5d249 + github.com/spf13/afero v1.9.3 github.com/spf13/jwalterweatherman v1.1.0 github.com/xlab/treeprint v1.2.0 golang.org/x/exp v0.0.0-20230131160201-f062dba9d201 @@ -160,7 +161,6 @@ require ( github.com/rogpeppe/go-internal v1.9.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/secure-systems-lab/go-securesystemslib v0.4.0 // indirect - github.com/spf13/afero v1.9.3 // indirect github.com/spf13/cast v1.5.0 // indirect github.com/subosito/gotenv v1.4.2 // indirect github.com/tidwall/match v1.1.1 // indirect diff --git a/go/viperutil/internal/sync/sync.go b/go/viperutil/internal/sync/sync.go index 11cb028c286..a5d35c504cb 100644 --- a/go/viperutil/internal/sync/sync.go +++ b/go/viperutil/internal/sync/sync.go @@ -23,6 +23,7 @@ import ( "time" "github.com/fsnotify/fsnotify" + "github.com/spf13/afero" "github.com/spf13/pflag" "github.com/spf13/viper" @@ -48,7 +49,17 @@ type Viper struct { subscribers []chan<- struct{} watchingConfig bool + fs afero.Fs + setCh chan struct{} + + // for testing purposes only + onConfigWrite func() +} + +func (v *Viper) SetFs(fs afero.Fs) { + v.fs = fs + v.disk.SetFs(fs) } // New returns a new synced Viper. @@ -57,6 +68,7 @@ func New() *Viper { disk: viper.New(), live: viper.New(), keys: map[string]*sync.RWMutex{}, + fs: afero.NewOsFs(), // default Fs used by viper, but we need this set so loadFromDisk doesn't accidentally nil-out the live fs setCh: make(chan struct{}, 1), } } @@ -217,6 +229,10 @@ func (v *Viper) persistChanges(ctx context.Context, minWaitInterval time.Duratio // WriteConfig writes the live viper config back to disk. func (v *Viper) WriteConfig() error { + if v.onConfigWrite != nil { + defer v.onConfigWrite() + } + for _, m := range v.keys { m.Lock() // This won't fire until after the config has been written. @@ -263,6 +279,7 @@ func (v *Viper) loadFromDisk() { // Reset v.live so explicit Set calls don't win over what's just changed on // disk. v.live = viper.New() + v.live.SetFs(v.fs) // Fun fact! MergeConfigMap actually only ever returns nil. Maybe in an // older version of viper it used to actually handle errors, but now it diff --git a/go/viperutil/internal/sync/sync_internal_test.go b/go/viperutil/internal/sync/sync_internal_test.go new file mode 100644 index 00000000000..cc8a163fa18 --- /dev/null +++ b/go/viperutil/internal/sync/sync_internal_test.go @@ -0,0 +1,136 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sync + +import ( + "context" + "encoding/json" + "math/rand" + "testing" + "time" + + "github.com/spf13/afero" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPersistConfig(t *testing.T) { + type config struct { + Foo int `json:"foo"` + } + + loadConfig := func(t *testing.T, fs afero.Fs) config { + t.Helper() + + data, err := afero.ReadFile(fs, "config.json") + require.NoError(t, err) + + var cfg config + require.NoError(t, json.Unmarshal(data, &cfg)) + + return cfg + } + + setup := func(t *testing.T, v *Viper, minWaitInterval time.Duration) (afero.Fs, <-chan struct{}) { + t.Helper() + + fs := afero.NewMemMapFs() + cfg := config{ + Foo: jitter(1, 100), + } + + data, err := json.Marshal(&cfg) + require.NoError(t, err) + + err = afero.WriteFile(fs, "config.json", data, 0644) + require.NoError(t, err) + + static := viper.New() + static.SetFs(fs) + static.SetConfigFile("config.json") + + require.NoError(t, static.ReadInConfig()) + require.Equal(t, cfg.Foo, static.GetInt("foo")) + + ch := make(chan struct{}, 1) + v.onConfigWrite = func() { ch <- struct{}{} } + v.SetFs(fs) + + cancel, err := v.Watch(context.Background(), static, minWaitInterval) + require.NoError(t, err) + + t.Cleanup(cancel) + return fs, ch + } + + t.Run("basic", func(t *testing.T) { + v := New() + + minPersistWaitInterval := 10 * time.Second + get := AdaptGetter("foo", func(v *viper.Viper) func(key string) int { return v.GetInt }, v) + fs, ch := setup(t, v, minPersistWaitInterval) + + old := get("foo") + loadConfig(t, fs) + v.Set("foo", old+1) + // This should happen immediately in-memory and on-disk. + assert.Equal(t, old+1, get("foo")) + <-ch + assert.Equal(t, old+1, loadConfig(t, fs).Foo) + + v.Set("foo", old+2) + // This should _also_ happen immediately in-memory, but not on-disk. + // It will take up to 2 * minPersistWaitInterval to reach the disk. + assert.Equal(t, old+2, get("foo")) + assert.Equal(t, old+1, loadConfig(t, fs).Foo) + + select { + case <-ch: + case <-time.After(3 * minPersistWaitInterval): + assert.Fail(t, "config was not persisted quickly enough", "config took longer than %s to persist (minPersistWaitInterval = %s)", 3*minPersistWaitInterval, minPersistWaitInterval) + } + + assert.Equal(t, old+2, loadConfig(t, fs).Foo) + }) + + t.Run("no persist interval", func(t *testing.T) { + v := New() + + var minPersistWaitInterval time.Duration + get := AdaptGetter("foo", func(v *viper.Viper) func(key string) int { return v.GetInt }, v) + fs, ch := setup(t, v, minPersistWaitInterval) + + old := get("foo") + loadConfig(t, fs) + v.Set("foo", old+1) + // This should happen immediately in-memory and on-disk. + assert.Equal(t, old+1, get("foo")) + <-ch + assert.Equal(t, old+1, loadConfig(t, fs).Foo) + + v.Set("foo", old+2) + // This should _also_ happen immediately in-memory, and on-disk. + assert.Equal(t, old+2, get("foo")) + <-ch + assert.Equal(t, old+2, loadConfig(t, fs).Foo) + }) +} + +func jitter(min, max int) int { + return min + rand.Intn(max-min+1) +} diff --git a/go/viperutil/internal/sync/sync_test.go b/go/viperutil/internal/sync/sync_test.go index df494c19bae..50e46a2c240 100644 --- a/go/viperutil/internal/sync/sync_test.go +++ b/go/viperutil/internal/sync/sync_test.go @@ -22,14 +22,12 @@ import ( "fmt" "math/rand" "os" - "strings" "sync" "testing" "time" "github.com/fsnotify/fsnotify" "github.com/spf13/viper" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "vitess.io/vitess/go/viperutil" @@ -37,107 +35,6 @@ import ( "vitess.io/vitess/go/viperutil/internal/value" ) -func TestPersistConfig(t *testing.T) { - t.Skip("temporarily skipping this to unblock PRs since it's flaky") - type config struct { - Foo int `json:"foo"` - } - - loadConfig := func(t *testing.T, f *os.File) config { - t.Helper() - - data, err := os.ReadFile(f.Name()) - require.NoError(t, err) - - var cfg config - require.NoError(t, json.Unmarshal(data, &cfg)) - - return cfg - } - - setup := func(t *testing.T, v *vipersync.Viper, minWaitInterval time.Duration) (*os.File, chan struct{}) { - tmp, err := os.CreateTemp(t.TempDir(), fmt.Sprintf("%s_*.json", strings.ReplaceAll(t.Name(), "/", "_"))) - require.NoError(t, err) - - t.Cleanup(func() { os.Remove(tmp.Name()) }) - - cfg := config{ - Foo: jitter(1, 100), - } - - data, err := json.Marshal(&cfg) - require.NoError(t, err) - - _, err = tmp.Write(data) - require.NoError(t, err) - - static := viper.New() - static.SetConfigFile(tmp.Name()) - require.NoError(t, static.ReadInConfig()) - - ch := make(chan struct{}, 1) - v.Notify(ch) - - cancel, err := v.Watch(context.Background(), static, minWaitInterval) - require.NoError(t, err) - t.Cleanup(cancel) - - return tmp, ch - } - - t.Run("basic", func(t *testing.T) { - v := vipersync.New() - - minPersistWaitInterval := 10 * time.Second - get := vipersync.AdaptGetter("foo", viperutil.GetFuncForType[int](), v) - f, ch := setup(t, v, minPersistWaitInterval) - - old := get("foo") - loadConfig(t, f) - v.Set("foo", old+1) - // This should happen immediately in-memory and on-disk. - assert.Equal(t, old+1, get("foo")) - <-ch - assert.Equal(t, old+1, loadConfig(t, f).Foo) - - v.Set("foo", old+2) - // This should _also_ happen immediately in-memory, but not on-disk. - // It will take up to 2 * minPersistWaitInterval to reach the disk. - assert.Equal(t, old+2, get("foo")) - assert.Equal(t, old+1, loadConfig(t, f).Foo) - - select { - case <-ch: - case <-time.After(2 * minPersistWaitInterval): - assert.Fail(t, "config was not persisted quickly enough", "config took longer than %s to persist (minPersistWaitInterval = %s)", 2*minPersistWaitInterval, minPersistWaitInterval) - } - - assert.Equal(t, old+2, loadConfig(t, f).Foo) - }) - - t.Run("no persist interval", func(t *testing.T) { - v := vipersync.New() - - var minPersistWaitInterval time.Duration - get := vipersync.AdaptGetter("foo", viperutil.GetFuncForType[int](), v) - f, ch := setup(t, v, minPersistWaitInterval) - - old := get("foo") - loadConfig(t, f) - v.Set("foo", old+1) - // This should happen immediately in-memory and on-disk. - assert.Equal(t, old+1, get("foo")) - <-ch - assert.Equal(t, old+1, loadConfig(t, f).Foo) - - v.Set("foo", old+2) - // This should _also_ happen immediately in-memory, and on-disk. - assert.Equal(t, old+2, get("foo")) - <-ch - assert.Equal(t, old+2, loadConfig(t, f).Foo) - }) -} - func TestWatchConfig(t *testing.T) { type config struct { A, B int @@ -156,7 +53,26 @@ func TestWatchConfig(t *testing.T) { return err } - return os.WriteFile(tmp.Name(), data, stat.Mode()) + err = os.WriteFile(tmp.Name(), data, stat.Mode()) + if err != nil { + return err + } + + data, err = os.ReadFile(tmp.Name()) + if err != nil { + return err + } + + var cfg config + if err := json.Unmarshal(data, &cfg); err != nil { + return err + } + + if cfg.A != a || cfg.B != b { + return fmt.Errorf("config did not persist; want %+v got %+v", config{A: a, B: b}, cfg) + } + + return nil } writeRandomConfig := func() error { a, b := rand.Intn(100), rand.Intn(100)