diff --git a/.changelog/15101.txt b/.changelog/15101.txt
new file mode 100644
index 00000000000..c76126f7918
--- /dev/null
+++ b/.changelog/15101.txt
@@ -0,0 +1,3 @@
+```release-note:bug
+csi: Fixed race condition that can cause a panic when volume is garbage collected
+```
diff --git a/nomad/volumewatcher/volume_watcher.go b/nomad/volumewatcher/volume_watcher.go
index 82cd76fb849..f0153e3a9f0 100644
--- a/nomad/volumewatcher/volume_watcher.go
+++ b/nomad/volumewatcher/volume_watcher.go
@@ -74,7 +74,6 @@ func (vw *volumeWatcher) Notify(v *structs.CSIVolume) {
 	select {
 	case vw.updateCh <- v:
 	case <-vw.shutdownCtx.Done(): // prevent deadlock if we stopped
-	case <-vw.ctx.Done(): // prevent deadlock if we stopped
 	}
 }
 
@@ -83,17 +82,14 @@ func (vw *volumeWatcher) Start() {
 	vw.wLock.Lock()
 	defer vw.wLock.Unlock()
 	vw.running = true
-	ctx, exitFn := context.WithCancel(vw.shutdownCtx)
-	vw.ctx = ctx
-	vw.exitFn = exitFn
 	go vw.watch()
 }
 
-// Stop stops watching the volume. This should be called whenever a
-// volume's claims are fully reaped or the watcher is no longer needed.
 func (vw *volumeWatcher) Stop() {
 	vw.logger.Trace("no more claims")
-	vw.exitFn()
+	vw.wLock.Lock()
+	defer vw.wLock.Unlock()
+	vw.running = false
 }
 
 func (vw *volumeWatcher) isRunning() bool {
@@ -102,8 +98,6 @@ func (vw *volumeWatcher) isRunning() bool {
 	select {
 	case <-vw.shutdownCtx.Done():
 		return false
-	case <-vw.ctx.Done():
-		return false
 	default:
 		return vw.running
 	}
@@ -113,12 +107,8 @@ func (vw *volumeWatcher) isRunning() bool {
 // Each pass steps the volume's claims through the various states of reaping
 // until the volume has no more claims eligible to be reaped.
 func (vw *volumeWatcher) watch() {
-	// always denormalize the volume and call reap when we first start
-	// the watcher so that we ensure we don't drop events that
-	// happened during leadership transitions and didn't get completed
-	// by the prior leader
-	vol := vw.getVolume(vw.v)
-	vw.volumeReap(vol)
+	defer vw.deleteFn()
+	defer vw.Stop()
 
 	timer, stop := helper.NewSafeTimer(vw.quiescentTimeout)
 	defer stop()
@@ -129,31 +119,17 @@ func (vw *volumeWatcher) watch() {
 		// context, so we can't stop the long-runner RPCs gracefully
 		case <-vw.shutdownCtx.Done():
 			return
-		case <-vw.ctx.Done():
-			return
 		case vol := <-vw.updateCh:
 			vol = vw.getVolume(vol)
 			if vol == nil {
-				// We stop the goroutine whenever we have no more
-				// work, but only delete the watcher when the volume
-				// is gone to avoid racing the blocking query
-				vw.deleteFn()
-				vw.Stop()
 				return
 			}
 			vw.volumeReap(vol)
			timer.Reset(vw.quiescentTimeout)
 		case <-timer.C:
-			// Wait until the volume has "settled" before stopping
-			// this goroutine so that the race between shutdown and
-			// the parent goroutine sending on <-updateCh is pushed to
-			// after the window we most care about quick freeing of
-			// claims (and the GC job will clean up anything we miss)
-			vol = vw.getVolume(vol)
-			if vol == nil {
-				vw.deleteFn()
-			}
-			vw.Stop()
+			// Wait until the volume has "settled" before stopping this
+			// goroutine so that we can handle the burst of updates around
+			// freeing claims without having to spin it back up
 			return
 		}
 	}
diff --git a/nomad/volumewatcher/volumes_watcher.go b/nomad/volumewatcher/volumes_watcher.go
index 5ae08e1ffb9..7f95ed659d4 100644
--- a/nomad/volumewatcher/volumes_watcher.go
+++ b/nomad/volumewatcher/volumes_watcher.go
@@ -188,6 +188,14 @@ func (w *Watcher) addLocked(v *structs.CSIVolume) (*volumeWatcher, error) {
 
 	watcher := newVolumeWatcher(w, v)
 	w.watchers[v.ID+v.Namespace] = watcher
+
+	// Sending the first volume update here before we return ensures we've hit
+	// the run loop in the goroutine before freeing the lock. This prevents a
+	// race between shutting down the watcher and the blocking query.
+	//
+	// It also ensures that we don't drop events that happened during leadership
+	// transitions and didn't get completed by the prior leader
+	watcher.updateCh <- v
 	return watcher, nil
 }
 
diff --git a/nomad/volumewatcher/volumes_watcher_test.go b/nomad/volumewatcher/volumes_watcher_test.go
index b4deda0c7b1..142b83dd450 100644
--- a/nomad/volumewatcher/volumes_watcher_test.go
+++ b/nomad/volumewatcher/volumes_watcher_test.go
@@ -10,6 +10,7 @@ import (
 	"github.com/hashicorp/nomad/nomad/mock"
 	"github.com/hashicorp/nomad/nomad/state"
 	"github.com/hashicorp/nomad/nomad/structs"
+	"github.com/shoenig/test/must"
 	"github.com/stretchr/testify/require"
 )
 
@@ -17,7 +18,6 @@ import (
 // to happen during leader step-up/step-down
 func TestVolumeWatch_EnableDisable(t *testing.T) {
 	ci.Parallel(t)
-	require := require.New(t)
 
 	srv := &MockRPCServer{}
 	srv.state = state.TestStateStore(t)
@@ -36,7 +36,7 @@ func TestVolumeWatch_EnableDisable(t *testing.T) {
 
 	index++
 	err := srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol})
-	require.NoError(err)
+	require.NoError(t, err)
 
 	// need to have just enough of a volume and claim in place so that
 	// the watcher doesn't immediately stop and unload itself
@@ -46,22 +46,23 @@ func TestVolumeWatch_EnableDisable(t *testing.T) {
 	}
 	index++
 	err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim)
-	require.NoError(err)
-	require.Eventually(func() bool {
+	require.NoError(t, err)
+	require.Eventually(t, func() bool {
 		watcher.wlock.RLock()
 		defer watcher.wlock.RUnlock()
 		return 1 == len(watcher.watchers)
 	}, time.Second, 10*time.Millisecond)
 
 	watcher.SetEnabled(false, nil, "")
-	require.Equal(0, len(watcher.watchers))
+	watcher.wlock.RLock()
+	defer watcher.wlock.RUnlock()
+	require.Equal(t, 0, len(watcher.watchers))
 }
 
 // TestVolumeWatch_LeadershipTransition tests the correct behavior of
 // claim reaping across leader step-up/step-down
 func TestVolumeWatch_LeadershipTransition(t *testing.T) {
 	ci.Parallel(t)
-	require := require.New(t)
 
 	srv := &MockRPCServer{}
 	srv.state = state.TestStateStore(t)
@@ -79,25 +80,25 @@ func TestVolumeWatch_LeadershipTransition(t *testing.T) {
 	index++
 	err := srv.State().UpsertAllocs(structs.MsgTypeTestSetup, index, []*structs.Allocation{alloc})
-	require.NoError(err)
+	require.NoError(t, err)
 
 	watcher.SetEnabled(true, srv.State(), "")
 
 	index++
 	err = srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol})
-	require.NoError(err)
+	require.NoError(t, err)
 
 	// we should get or start up a watcher when we get an update for
 	// the volume from the state store
-	require.Eventually(func() bool {
+	require.Eventually(t, func() bool {
 		watcher.wlock.RLock()
 		defer watcher.wlock.RUnlock()
 		return 1 == len(watcher.watchers)
 	}, time.Second, 10*time.Millisecond)
 
 	vol, _ = srv.State().CSIVolumeByID(nil, vol.Namespace, vol.ID)
-	require.Len(vol.PastClaims, 0, "expected to have 0 PastClaims")
-	require.Equal(srv.countCSIUnpublish, 0, "expected no CSI.Unpublish RPC calls")
+	require.Len(t, vol.PastClaims, 0, "expected to have 0 PastClaims")
+	require.Equal(t, srv.countCSIUnpublish, 0, "expected no CSI.Unpublish RPC calls")
 
 	// trying to test a dropped watch is racy, so to reliably simulate
 	// this condition, step-down the watcher first and then perform
 	// the writes
@@ -106,12 +107,14 @@ func TestVolumeWatch_LeadershipTransition(t *testing.T) {
 
 	// step-down (this is sync)
 	watcher.SetEnabled(false, nil, "")
-	require.Equal(0, len(watcher.watchers))
+	watcher.wlock.RLock()
+	require.Equal(t, 0, len(watcher.watchers))
+	watcher.wlock.RUnlock()
 
 	// allocation is now invalid
 	index++
 	err = srv.State().DeleteEval(index, []string{}, []string{alloc.ID})
-	require.NoError(err)
+	require.NoError(t, err)
 
 	// emit a GC so that we have a volume change that's dropped
 	claim := &structs.CSIVolumeClaim{
@@ -122,7 +125,7 @@ func TestVolumeWatch_LeadershipTransition(t *testing.T) {
 	}
 	index++
 	err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim)
-	require.NoError(err)
+	require.NoError(t, err)
 
 	// create a new watcher and enable it to simulate the leadership
 	// transition
@@ -130,23 +133,21 @@ func TestVolumeWatch_LeadershipTransition(t *testing.T) {
 	watcher.quiescentTimeout = 100 * time.Millisecond
 	watcher.SetEnabled(true, srv.State(), "")
 
-	require.Eventually(func() bool {
+	require.Eventually(t, func() bool {
 		watcher.wlock.RLock()
 		defer watcher.wlock.RUnlock()
-		return 1 == len(watcher.watchers) &&
-			!watcher.watchers[vol.ID+vol.Namespace].isRunning()
+		return 0 == len(watcher.watchers)
 	}, time.Second, 10*time.Millisecond)
 
 	vol, _ = srv.State().CSIVolumeByID(nil, vol.Namespace, vol.ID)
-	require.Len(vol.PastClaims, 1, "expected to have 1 PastClaim")
-	require.Equal(srv.countCSIUnpublish, 1, "expected CSI.Unpublish RPC to be called")
+	require.Len(t, vol.PastClaims, 1, "expected to have 1 PastClaim")
+	require.Equal(t, srv.countCSIUnpublish, 1, "expected CSI.Unpublish RPC to be called")
 }
 
 // TestVolumeWatch_StartStop tests the start and stop of the watcher when
 // it receives notifcations and has completed its work
 func TestVolumeWatch_StartStop(t *testing.T) {
 	ci.Parallel(t)
-	require := require.New(t)
 
 	srv := &MockStatefulRPCServer{}
 	srv.state = state.TestStateStore(t)
@@ -155,7 +156,7 @@ func TestVolumeWatch_StartStop(t *testing.T) {
 	watcher.quiescentTimeout = 100 * time.Millisecond
 
 	watcher.SetEnabled(true, srv.State(), "")
-	require.Equal(0, len(watcher.watchers))
+	require.Equal(t, 0, len(watcher.watchers))
 
 	plugin := mock.CSIPlugin()
 	node := testNode(plugin, srv.State())
@@ -166,23 +167,22 @@ func TestVolumeWatch_StartStop(t *testing.T) {
 	alloc2.ClientStatus = structs.AllocClientStatusRunning
 	index++
 	err := srv.State().UpsertJob(structs.MsgTypeTestSetup, index, alloc1.Job)
-	require.NoError(err)
+	require.NoError(t, err)
 	index++
 	err = srv.State().UpsertAllocs(structs.MsgTypeTestSetup, index, []*structs.Allocation{alloc1, alloc2})
-	require.NoError(err)
+	require.NoError(t, err)
 
-	// register a volume
+	// register a volume and an unused volume
 	vol := testVolume(plugin, alloc1, node.ID)
 	index++
 	err = srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol})
-	require.NoError(err)
+	require.NoError(t, err)
 
 	// assert we get a watcher; there are no claims so it should immediately stop
-	require.Eventually(func() bool {
+	require.Eventually(t, func() bool {
 		watcher.wlock.RLock()
 		defer watcher.wlock.RUnlock()
-		return 1 == len(watcher.watchers) &&
-			!watcher.watchers[vol.ID+vol.Namespace].isRunning()
+		return 0 == len(watcher.watchers)
 	}, time.Second*2, 10*time.Millisecond)
 
 	// claim the volume for both allocs
@@ -195,11 +195,11 @@ func TestVolumeWatch_StartStop(t *testing.T) {
 
 	index++
 	err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim)
-	require.NoError(err)
+	require.NoError(t, err)
 	claim.AllocationID = alloc2.ID
 	index++
 	err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim)
-	require.NoError(err)
+	require.NoError(t, err)
 
 	// reap the volume and assert nothing has happened
 	claim = &structs.CSIVolumeClaim{
@@ -208,41 +208,88 @@ func TestVolumeWatch_StartStop(t *testing.T) {
 	}
 	index++
 	err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim)
-	require.NoError(err)
+	require.NoError(t, err)
 
 	ws := memdb.NewWatchSet()
 	vol, _ = srv.State().CSIVolumeByID(ws, vol.Namespace, vol.ID)
-	require.Equal(2, len(vol.ReadAllocs))
+	require.Equal(t, 2, len(vol.ReadAllocs))
 
 	// alloc becomes terminal
+	alloc1 = alloc1.Copy()
 	alloc1.ClientStatus = structs.AllocClientStatusComplete
 	index++
 	err = srv.State().UpsertAllocs(structs.MsgTypeTestSetup, index, []*structs.Allocation{alloc1})
-	require.NoError(err)
+	require.NoError(t, err)
 	index++
 	claim.State = structs.CSIVolumeClaimStateReadyToFree
 	err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim)
-	require.NoError(err)
+	require.NoError(t, err)
 
-	// 1 claim has been released and watcher stops
-	require.Eventually(func() bool {
-		ws := memdb.NewWatchSet()
-		vol, _ := srv.State().CSIVolumeByID(ws, vol.Namespace, vol.ID)
-		return len(vol.ReadAllocs) == 1 && len(vol.PastClaims) == 0
+	// watcher stops and 1 claim has been released
+	require.Eventually(t, func() bool {
+		watcher.wlock.RLock()
+		defer watcher.wlock.RUnlock()
+		return 0 == len(watcher.watchers)
+	}, time.Second*5, 10*time.Millisecond)
+
+	vol, _ = srv.State().CSIVolumeByID(ws, vol.Namespace, vol.ID)
+	must.Eq(t, 1, len(vol.ReadAllocs))
+	must.Eq(t, 0, len(vol.PastClaims))
+}
+
+// TestVolumeWatch_Delete tests the stop of the watcher when it receives
+// notifications around a deleted volume
+func TestVolumeWatch_Delete(t *testing.T) {
+	ci.Parallel(t)
+
+	srv := &MockStatefulRPCServer{}
+	srv.state = state.TestStateStore(t)
+	index := uint64(100)
+	watcher := NewVolumesWatcher(testlog.HCLogger(t), srv, "")
+	watcher.quiescentTimeout = 100 * time.Millisecond
+
+	watcher.SetEnabled(true, srv.State(), "")
+	must.Eq(t, 0, len(watcher.watchers))
+
+	// register an unused volume
+	plugin := mock.CSIPlugin()
+	vol := mock.CSIVolume(plugin)
+	index++
+	must.NoError(t, srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol}))
+
+	// assert we get a watcher; there are no claims so it should immediately stop
+	require.Eventually(t, func() bool {
+		watcher.wlock.RLock()
+		defer watcher.wlock.RUnlock()
+		return 0 == len(watcher.watchers)
 	}, time.Second*2, 10*time.Millisecond)
 
-	require.Eventually(func() bool {
+	// write a GC claim to the volume and then immediately delete, to
+	// potentially hit the race condition between updates and deletes
+	index++
+	must.NoError(t, srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID,
+		&structs.CSIVolumeClaim{
+			Mode:  structs.CSIVolumeClaimGC,
+			State: structs.CSIVolumeClaimStateReadyToFree,
+		}))
+
+	index++
+	must.NoError(t, srv.State().CSIVolumeDeregister(
+		index, vol.Namespace, []string{vol.ID}, false))
+
+	// the watcher should not be running
+	require.Eventually(t, func() bool {
 		watcher.wlock.RLock()
 		defer watcher.wlock.RUnlock()
-		return !watcher.watchers[vol.ID+vol.Namespace].isRunning()
+		return 0 == len(watcher.watchers)
 	}, time.Second*5, 10*time.Millisecond)
+
 }
 
 // TestVolumeWatch_RegisterDeregister tests the start and stop of
 // watchers around registration
 func TestVolumeWatch_RegisterDeregister(t *testing.T) {
 	ci.Parallel(t)
-	require := require.New(t)
 
 	srv := &MockStatefulRPCServer{}
 	srv.state = state.TestStateStore(t)
@@ -253,7 +300,7 @@ func TestVolumeWatch_RegisterDeregister(t *testing.T) {
 	watcher.quiescentTimeout = 10 * time.Millisecond
 
 	watcher.SetEnabled(true, srv.State(), "")
-	require.Equal(0, len(watcher.watchers))
+	require.Equal(t, 0, len(watcher.watchers))
 	plugin := mock.CSIPlugin()
 	alloc := mock.Alloc()
 	alloc.ClientStatus = structs.AllocClientStatusComplete
@@ -263,18 +310,12 @@ func TestVolumeWatch_RegisterDeregister(t *testing.T) {
 	vol := mock.CSIVolume(plugin)
 	index++
 	err := srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol})
-	require.NoError(err)
+	require.NoError(t, err)
 
-	// watcher should be started but immediately stopped
-	require.Eventually(func() bool {
+	// watcher should stop
+	require.Eventually(t, func() bool {
 		watcher.wlock.RLock()
 		defer watcher.wlock.RUnlock()
-		return 1 == len(watcher.watchers)
+		return 0 == len(watcher.watchers)
 	}, time.Second, 10*time.Millisecond)
-
-	require.Eventually(func() bool {
-		watcher.wlock.RLock()
-		defer watcher.wlock.RUnlock()
-		return !watcher.watchers[vol.ID+vol.Namespace].isRunning()
-	}, 1*time.Second, 10*time.Millisecond)
 }