diff --git a/server/etcdserver/api/v3rpc/watch.go b/server/etcdserver/api/v3rpc/watch.go
index 5153007258d..a8d37efa098 100644
--- a/server/etcdserver/api/v3rpc/watch.go
+++ b/server/etcdserver/api/v3rpc/watch.go
@@ -145,6 +145,10 @@ type serverWatchStream struct {
 	// records fragmented watch IDs
 	fragment map[mvcc.WatchID]bool

+	// indicates whether we have an outstanding global progress
+	// notification to send
+	deferredProgress bool
+
 	// closec indicates the stream is closed.
 	closec chan struct{}

@@ -174,6 +178,8 @@ func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) {
 		prevKV:   make(map[mvcc.WatchID]bool),
 		fragment: make(map[mvcc.WatchID]bool),

+		deferredProgress: false,
+
 		closec: make(chan struct{}),
 	}

@@ -360,10 +366,16 @@ func (sws *serverWatchStream) recvLoop() error {
 			}
 		case *pb.WatchRequest_ProgressRequest:
 			if uv.ProgressRequest != nil {
-				sws.ctrlStream <- &pb.WatchResponse{
-					Header:  sws.newResponseHeader(sws.watchStream.Rev()),
-					WatchId: clientv3.InvalidWatchID, // response is not associated with any WatchId and will be broadcast to all watch channels
+				sws.mu.Lock()
+				// Ignore if deferred progress notification is already in progress
+				if !sws.deferredProgress {
+					// Request progress for all watchers,
+					// force generation of a response
+					if !sws.watchStream.RequestProgressAll() {
+						sws.deferredProgress = true
+					}
 				}
+				sws.mu.Unlock()
 			}
 		default:
 			// we probably should not shutdown the entire stream when
@@ -432,11 +444,15 @@ func (sws *serverWatchStream) sendLoop() {
 				Canceled:        canceled,
 			}

-			if _, okID := ids[wresp.WatchID]; !okID {
-				// buffer if id not yet announced
-				wrs := append(pending[wresp.WatchID], wr)
-				pending[wresp.WatchID] = wrs
-				continue
+			// Progress notifications can have WatchID -1
+			// if they announce on behalf of multiple watchers
+			if wresp.WatchID != clientv3.InvalidWatchID {
+				if _, okID := ids[wresp.WatchID]; !okID {
+					// buffer if id not yet announced
+					wrs := append(pending[wresp.WatchID], wr)
+					pending[wresp.WatchID] = wrs
+					continue
+				}
 			}

 			mvcc.ReportEventReceived(len(evs))
@@ -467,6 +483,11 @@ func (sws *serverWatchStream) sendLoop() {
 				// elide next progress update if sent a key update
 				sws.progress[wresp.WatchID] = false
 			}
+			if sws.deferredProgress {
+				if sws.watchStream.RequestProgressAll() {
+					sws.deferredProgress = false
+				}
+			}
 			sws.mu.Unlock()

 		case c, ok := <-sws.ctrlStream:
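The server-side change above only matters when a client explicitly requests progress on a stream that still has unsynced watchers. The following minimal client-side sketch (not part of the patch) shows that call path; the endpoint, key, and function name are illustrative, while clientv3.New, Watch, WithRev, RequestProgress, and IsProgressNotify are existing clientv3 APIs.

package example

import (
	"context"
	"time"

	clientv3 "go.etcd.io/etcd/client/v3"
)

// watchWithManualProgress opens a watch at an old revision (so the
// server-side watcher starts unsynced) and immediately requests a progress
// notification. With the deferredProgress change above, the notification is
// held back until the watcher has caught up.
func watchWithManualProgress(ctx context.Context, endpoint, key string) error {
	cli, err := clientv3.New(clientv3.Config{
		Endpoints:   []string{endpoint},
		DialTimeout: 5 * time.Second,
	})
	if err != nil {
		return err
	}
	defer cli.Close()

	// Watch from revision 1 so the new watcher has history to replay.
	wch := cli.Watch(ctx, key, clientv3.WithRev(1))

	// Ask for a progress notification right away; the server now defers
	// it instead of answering with a revision the watch has not reached.
	if err := cli.RequestProgress(ctx); err != nil {
		return err
	}

	for wr := range wch {
		if wr.Err() != nil {
			return wr.Err()
		}
		if wr.IsProgressNotify() {
			// Arrives only after all historical events have been delivered.
			return nil
		}
	}
	return ctx.Err()
}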
diff --git a/server/storage/mvcc/watchable_store.go b/server/storage/mvcc/watchable_store.go
index f0d056f286d..4e7b5a71407 100644
--- a/server/storage/mvcc/watchable_store.go
+++ b/server/storage/mvcc/watchable_store.go
@@ -19,6 +19,7 @@ import (
 	"time"

 	"go.etcd.io/etcd/api/v3/mvccpb"
+	clientv3 "go.etcd.io/etcd/client/v3"
 	"go.etcd.io/etcd/pkg/v3/traceutil"
 	"go.etcd.io/etcd/server/v3/lease"
 	"go.etcd.io/etcd/server/v3/storage/backend"
@@ -41,6 +42,7 @@ var (
 type watchable interface {
 	watch(key, end []byte, startRev int64, id WatchID, ch chan<- WatchResponse, fcs ...FilterFunc) (*watcher, cancelFunc)
 	progress(w *watcher)
+	progressAll(watchers map[WatchID]*watcher) bool
 	rev() int64
 }

@@ -475,14 +477,34 @@ func (s *watchableStore) addVictim(victim watcherBatch) {
 func (s *watchableStore) rev() int64 { return s.store.Rev() }

 func (s *watchableStore) progress(w *watcher) {
+	s.progressIfSync(map[WatchID]*watcher{w.id: w}, w.id)
+}
+
+func (s *watchableStore) progressAll(watchers map[WatchID]*watcher) bool {
+	return s.progressIfSync(watchers, clientv3.InvalidWatchID)
+}
+
+func (s *watchableStore) progressIfSync(watchers map[WatchID]*watcher, responseWatchID WatchID) bool {
 	s.mu.RLock()
 	defer s.mu.RUnlock()

-	if _, ok := s.synced.watchers[w]; ok {
-		w.send(WatchResponse{WatchID: w.id, Revision: s.rev()})
-		// If the ch is full, this watcher is receiving events.
-		// We do not need to send progress at all.
+	// Any watcher unsynced?
+	for _, w := range watchers {
+		if _, ok := s.synced.watchers[w]; !ok {
+			return false
+		}
+	}
+
+	// If all watchers are synchronised, send out progress
+	// notification on first watcher. Note that all watchers
+	// should have the same underlying stream, and the progress
+	// notification will be broadcasted client-side if required
+	// (see dispatchEvent in client/v3/watch.go)
+	for _, w := range watchers {
+		w.send(WatchResponse{WatchID: responseWatchID, Revision: s.rev()})
+		return true
 	}
+	return true
 }

 type watcher struct {
diff --git a/server/storage/mvcc/watcher.go b/server/storage/mvcc/watcher.go
index 7d2490b1d6e..c67c21d6139 100644
--- a/server/storage/mvcc/watcher.go
+++ b/server/storage/mvcc/watcher.go
@@ -58,6 +58,13 @@ type WatchStream interface {
 	// of the watchers since the watcher is currently synced.
 	RequestProgress(id WatchID)

+	// RequestProgressAll requests a progress notification for all
+	// watchers sharing the stream. If all watchers are synced, a
+	// progress notification with watch ID -1 will be sent to an
+	// arbitrary watcher of this stream, and the function returns
+	// true.
+	RequestProgressAll() bool
+
 	// Cancel cancels a watcher by giving its ID. If watcher does not exist, an error will be
 	// returned.
 	Cancel(id WatchID) error
@@ -188,3 +195,9 @@ func (ws *watchStream) RequestProgress(id WatchID) {
 	}
 	ws.watchable.progress(w)
 }
+
+func (ws *watchStream) RequestProgressAll() bool {
+	ws.mu.Lock()
+	defer ws.mu.Unlock()
+	return ws.watchable.progressAll(ws.watchers)
+}
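For reference, a hedged sketch (not part of the patch) of the contract RequestProgressAll establishes for its caller: a false return means "not every watcher on this stream is synced yet, retry later". The helper name requestProgressDeferred is illustrative; the real caller is serverWatchStream above, which keeps the pending state in its deferredProgress field.

package example

import "go.etcd.io/etcd/server/v3/storage/mvcc"

// requestProgressDeferred condenses the recvLoop logic above: it issues a
// stream-wide progress request unless one is already pending, and returns
// the new value of the caller's deferred flag.
func requestProgressDeferred(ws mvcc.WatchStream, alreadyDeferred bool) bool {
	if alreadyDeferred {
		// A request is already pending; do not issue another one.
		return true
	}
	// RequestProgressAll sends a WatchID -1 notification and returns true
	// only when every watcher on the stream is synced.
	return !ws.RequestProgressAll()
}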
diff --git a/server/storage/mvcc/watcher_test.go b/server/storage/mvcc/watcher_test.go
index b86e31a5542..41bbb510875 100644
--- a/server/storage/mvcc/watcher_test.go
+++ b/server/storage/mvcc/watcher_test.go
@@ -25,6 +25,7 @@ import (
 	"go.uber.org/zap/zaptest"

 	"go.etcd.io/etcd/api/v3/mvccpb"
+	clientv3 "go.etcd.io/etcd/client/v3"
 	"go.etcd.io/etcd/server/v3/lease"
 	betesting "go.etcd.io/etcd/server/v3/storage/backend/testing"
 )
@@ -342,6 +343,55 @@ func TestWatcherRequestProgress(t *testing.T) {
 	}
 }

+func TestWatcherRequestProgressAll(t *testing.T) {
+	b, _ := betesting.NewDefaultTmpBackend(t)
+
+	// manually create watchableStore instead of newWatchableStore
+	// because newWatchableStore automatically calls syncWatchers
+	// method to sync watchers in unsynced map. We want to keep watchers
+	// in unsynced to test if syncWatchers works as expected.
+	s := &watchableStore{
+		store:    NewStore(zaptest.NewLogger(t), b, &lease.FakeLessor{}, StoreConfig{}),
+		unsynced: newWatcherGroup(),
+		synced:   newWatcherGroup(),
+		stopc:    make(chan struct{}),
+	}
+
+	defer cleanup(s, b)
+
+	testKey := []byte("foo")
+	notTestKey := []byte("bad")
+	testValue := []byte("bar")
+	s.Put(testKey, testValue, lease.NoLease)
+
+	// Create watch stream with watcher. We will not actually get
+	// any notifications on it specifically, but there needs to be
+	// at least one Watch for progress notifications to get
+	// generated.
+	w := s.NewWatchStream()
+	w.Watch(0, notTestKey, nil, 1)
+
+	w.RequestProgressAll()
+	select {
+	case resp := <-w.Chan():
+		t.Fatalf("unexpected %+v", resp)
+	default:
+	}
+
+	s.syncWatchers()
+
+	w.RequestProgressAll()
+	wrs := WatchResponse{WatchID: clientv3.InvalidWatchID, Revision: 2}
+	select {
+	case resp := <-w.Chan():
+		if !reflect.DeepEqual(resp, wrs) {
+			t.Fatalf("got %+v, expect %+v", resp, wrs)
+		}
+	case <-time.After(time.Second):
+		t.Fatal("failed to receive progress")
+	}
+}
+
 func TestWatcherWatchWithFilter(t *testing.T) {
 	b, _ := betesting.NewDefaultTmpBackend(t)
 	s := WatchableKV(newWatchableStore(zaptest.NewLogger(t), b, &lease.FakeLessor{}, StoreConfig{}))
diff --git a/tests/integration/v3_watch_test.go b/tests/integration/v3_watch_test.go
index c2a4fa57645..ead81abc174 100644
--- a/tests/integration/v3_watch_test.go
+++ b/tests/integration/v3_watch_test.go
@@ -1397,3 +1397,71 @@ func TestV3WatchCloseCancelRace(t *testing.T) {
 		t.Fatalf("expected %s watch, got %s", expected, minWatches)
 	}
 }
+
+// TestV3WatchProgressWaitsForSync checks that progress notifications
+// don't get sent until the watcher is synchronised
+func TestV3WatchProgressWaitsForSync(t *testing.T) {
+
+	// Disable for gRPC proxy, as it does not support requesting
+	// progress notifications
+	if integration.ThroughProxy {
+		t.Skip("grpc proxy currently does not support requesting progress notifications")
+	}
+
+	integration.BeforeTest(t)
+
+	clus := integration.NewCluster(t, &integration.ClusterConfig{Size: 1})
+	defer clus.Terminate(t)
+
+	client := clus.RandClient()
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	// Write a couple values into key to make sure there's a
+	// non-trivial amount of history.
+	count := 1001
+	t.Logf("Writing key 'foo' %d times", count)
+	for i := 0; i < count; i++ {
+		_, err := client.Put(ctx, "foo", fmt.Sprintf("bar%d", i))
+		require.NoError(t, err)
+	}
+
+	// Create watch channel starting at revision 1 (i.e. it starts
+	// unsynced because of the update above)
+	wch := client.Watch(ctx, "foo", clientv3.WithRev(1))
+
+	// Immediately request a progress notification. As the client
+	// is unsynchronised, the server will have to defer the
+	// notification internally.
+	err := client.RequestProgress(ctx)
+	require.NoError(t, err)
+
+	// Verify that we get the watch responses first. Note that
+	// events might be spread across multiple packets.
+	var event_count = 0
+	for event_count < count {
+		wr := <-wch
+		if wr.Err() != nil {
+			t.Fatal(fmt.Errorf("watch error: %w", wr.Err()))
+		}
+		if wr.IsProgressNotify() {
+			t.Fatal("Progress notification from unsynced client!")
+		}
+		if wr.Header.Revision != int64(count+1) {
+			t.Fatal("Incomplete watch response!")
+		}
+		event_count += len(wr.Events)
+	}
+
+	// ... followed by the requested progress notification
+	wr2 := <-wch
+	if wr2.Err() != nil {
+		t.Fatal(fmt.Errorf("watch error: %w", wr2.Err()))
+	}
+	if !wr2.IsProgressNotify() {
+		t.Fatal("Did not receive progress notification!")
+	}
+	if wr2.Header.Revision != int64(count+1) {
+		t.Fatal("Wrong revision in progress notification!")
+	}
+}
diff --git a/tests/robustness/linearizability_test.go b/tests/robustness/linearizability_test.go
index 358533e1f16..070556c75fd 100644
--- a/tests/robustness/linearizability_test.go
+++ b/tests/robustness/linearizability_test.go
@@ -36,10 +36,11 @@ const (

 var (
 	LowTraffic = trafficConfig{
-		name:        "LowTraffic",
-		minimalQPS:  100,
-		maximalQPS:  200,
-		clientCount: 8,
+		name:            "LowTraffic",
+		minimalQPS:      100,
+		maximalQPS:      200,
+		clientCount:     8,
+		requestProgress: false,
 		traffic: traffic{
 			keyCount: 10,
 			leaseTTL: DefaultLeaseTTL,
@@ -56,10 +57,11 @@ var (
 		},
 	}
 	HighTraffic = trafficConfig{
-		name:        "HighTraffic",
-		minimalQPS:  200,
-		maximalQPS:  1000,
-		clientCount: 12,
+		name:            "HighTraffic",
+		minimalQPS:      200,
+		maximalQPS:      1000,
+		clientCount:     12,
+		requestProgress: false,
 		traffic: traffic{
 			keyCount:     10,
 			largePutSize: 32769,
@@ -71,6 +73,22 @@ var (
 			},
 		},
 	}
+	ReqProgTraffic = trafficConfig{
+		name:            "RequestProgressTraffic",
+		minimalQPS:      200,
+		maximalQPS:      1000,
+		clientCount:     12,
+		requestProgress: true,
+		traffic: traffic{
+			keyCount:     10,
+			largePutSize: 8196,
+			leaseTTL:     DefaultLeaseTTL,
+			writes: []requestChance{
+				{operation: Put, chance: 95},
+				{operation: LargePut, chance: 5},
+			},
+		},
+	}
 	defaultTraffic = LowTraffic
 	trafficList    = []trafficConfig{
 		LowTraffic, HighTraffic,
@@ -141,6 +159,14 @@ func TestRobustness(t *testing.T) {
 			e2e.WithSnapshotCount(100),
 		),
 	})
+	scenarios = append(scenarios, scenario{
+		name:      "Issue15220",
+		failpoint: RandomOneNodeClusterFailpoint,
+		traffic:   &ReqProgTraffic,
+		config: *e2e.NewConfig(
+			e2e.WithClusterSize(1),
+		),
+	})
 	snapshotOptions := []e2e.EPClusterOption{
 		e2e.WithGoFailEnabled(true),
 		e2e.WithSnapshotCount(100),
@@ -191,7 +217,7 @@ func testRobustness(ctx context.Context, t *testing.T, lg *zap.Logger, config e2
 	forcestopCluster(r.clus)

 	watchProgressNotifyEnabled := r.clus.Cfg.WatchProcessNotifyInterval != 0
-	validateWatchResponses(t, r.responses, watchProgressNotifyEnabled)
+	validateWatchResponses(t, r.responses, traffic.requestProgress || watchProgressNotifyEnabled)

 	r.events = watchEvents(r.responses)
 	validateEventsMatch(t, r.events)
@@ -218,7 +244,7 @@ func runScenario(ctx context.Context, t *testing.T, lg *zap.Logger, clus *e2e.Et
 		return nil
 	})
 	g.Go(func() error {
-		responses = collectClusterWatchEvents(ctx, t, clus, maxRevisionChan)
+		responses = collectClusterWatchEvents(ctx, t, clus, maxRevisionChan, traffic.requestProgress)
 		return nil
 	})
 	g.Wait()
diff --git a/tests/robustness/traffic.go b/tests/robustness/traffic.go
index 84298957023..fa5889fc94e 100644
--- a/tests/robustness/traffic.go
+++ b/tests/robustness/traffic.go
@@ -109,11 +109,12 @@ func simulateTraffic(ctx context.Context, t *testing.T, lg *zap.Logger, clus *e2
 }

 type trafficConfig struct {
-	name        string
-	minimalQPS  float64
-	maximalQPS  float64
-	clientCount int
-	traffic     Traffic
+	name            string
+	minimalQPS      float64
+	maximalQPS      float64
+	clientCount     int
+	traffic         Traffic
+	requestProgress bool // Request progress notifications while watching this traffic
 }

 type Traffic interface {
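The validateWatchResponses change above turns on progress-notification checks whenever the traffic profile requests progress manually. The guarantee being exercised, stated as code in the hedged sketch below: once a progress notification reports revision R, every later event on the same watch must carry a higher revision. The checkProgressOrdering helper is illustrative and is not the validateWatchResponses implementation.

package example

import (
	"fmt"

	clientv3 "go.etcd.io/etcd/client/v3"
)

// checkProgressOrdering verifies that no event is delivered after a progress
// notification that already claimed its revision was fully observed; a
// violation means the notification was sent before the watcher was synced.
func checkProgressOrdering(resps []clientv3.WatchResponse) error {
	var lastProgressRev int64
	for _, wr := range resps {
		if wr.IsProgressNotify() {
			lastProgressRev = wr.Header.Revision
			continue
		}
		for _, ev := range wr.Events {
			if ev.Kv.ModRevision <= lastProgressRev {
				return fmt.Errorf("event at revision %d delivered after progress notification for revision %d",
					ev.Kv.ModRevision, lastProgressRev)
			}
		}
	}
	return nil
}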
diff --git a/tests/robustness/watch.go b/tests/robustness/watch.go
index 9038d961876..79ba408a656 100644
--- a/tests/robustness/watch.go
+++ b/tests/robustness/watch.go
@@ -31,7 +31,7 @@ import (
 	"go.etcd.io/etcd/tests/v3/robustness/model"
 )

-func collectClusterWatchEvents(ctx context.Context, t *testing.T, clus *e2e.EtcdProcessCluster, maxRevisionChan <-chan int64) [][]watchResponse {
+func collectClusterWatchEvents(ctx context.Context, t *testing.T, clus *e2e.EtcdProcessCluster, maxRevisionChan <-chan int64, requestProgress bool) [][]watchResponse {
 	mux := sync.Mutex{}
 	var wg sync.WaitGroup
 	memberResponses := make([][]watchResponse, len(clus.Procs))
@@ -52,7 +52,7 @@ func collectClusterWatchEvents(ctx context.Context, t *testing.T, clus *e2e.Etcd
 		go func(i int, c *clientv3.Client) {
 			defer wg.Done()
 			defer c.Close()
-			responses := watchMember(ctx, t, c, memberChan)
+			responses := watchMember(ctx, t, c, memberChan, requestProgress)
 			mux.Lock()
 			memberResponses[i] = responses
 			mux.Unlock()
@@ -71,7 +71,7 @@ func collectClusterWatchEvents(ctx context.Context, t *testing.T, clus *e2e.Etcd
 }

 // watchMember collects all responses until context is cancelled, it has observed revision provided via maxRevisionChan or maxRevisionChan was closed.
-func watchMember(ctx context.Context, t *testing.T, c *clientv3.Client, maxRevisionChan <-chan int64) (resps []watchResponse) {
+func watchMember(ctx context.Context, t *testing.T, c *clientv3.Client, maxRevisionChan <-chan int64, requestProgress bool) (resps []watchResponse) {
 	var maxRevision int64 = 0
 	var lastRevision int64 = 0
 	ctx, cancel := context.WithCancel(ctx)
@@ -101,6 +101,9 @@ func watchMember(ctx context.Context, t *testing.T, c *clientv3.Client, maxRevis
 				}
 			}
 		case resp := <-watch:
+			if requestProgress {
+				c.RequestProgress(ctx)
+			}
 			if resp.Err() == nil {
 				resps = append(resps, watchResponse{resp, time.Now()})
 			} else if !resp.Canceled {