diff --git a/proxy/grpcproxy/watch.go b/proxy/grpcproxy/watch.go
index 9cfe6c72470..db3c3fc3759 100644
--- a/proxy/grpcproxy/watch.go
+++ b/proxy/grpcproxy/watch.go
@@ -59,6 +59,7 @@ func (wp *watchProxy) Watch(stream pb.Watch_WatchServer) (err error) {
 	sid := wp.nextStreamID
 	wp.mu.Unlock()
 
+	ctx, cancel := context.WithCancel(wp.ctx)
 	sws := serverWatchStream{
 		cw:       wp.cw,
 		groups:   &wp.wgs,
@@ -70,7 +71,8 @@ func (wp *watchProxy) Watch(stream pb.Watch_WatchServer) (err error) {
 
 		watchCh: make(chan *pb.WatchResponse, 1024),
 
-		proxyCtx: wp.ctx,
+		ctx:    ctx,
+		cancel: cancel,
 	}
 
 	go sws.recvLoop()
@@ -93,11 +95,13 @@ type serverWatchStream struct {
 
 	nextWatcherID int64
 
-	proxyCtx context.Context
+	ctx    context.Context
+	cancel context.CancelFunc
 }
 
 func (sws *serverWatchStream) close() {
 	var wg sync.WaitGroup
+	sws.cancel()
 	sws.mu.Lock()
 	wg.Add(len(sws.singles) + len(sws.inGroups))
 	for _, ws := range sws.singles {
@@ -145,8 +149,8 @@ func (sws *serverWatchStream) recvLoop() error {
 					key: string(cr.Key),
 					end: string(cr.RangeEnd),
 				},
-				id: sws.nextWatcherID,
-				ch: sws.watchCh,
+				id:  sws.nextWatcherID,
+				sws: sws,
 
 				progress: cr.ProgressNotify,
 				filters:  v3rpc.FiltersFromRequest(cr),
@@ -176,7 +180,7 @@ func (sws *serverWatchStream) sendLoop() {
 			if err := sws.gRPCStream.Send(wresp); err != nil {
 				return
 			}
-		case <-sws.proxyCtx.Done():
+		case <-sws.ctx.Done():
 			return
 		}
 	}
@@ -192,18 +196,15 @@ func (sws *serverWatchStream) addCoalescedWatcher(w watcher) {
 }
 
 func (sws *serverWatchStream) addDedicatedWatcher(w watcher, rev int64) {
-	sws.mu.Lock()
-	defer sws.mu.Unlock()
-
-	ctx, cancel := context.WithCancel(sws.proxyCtx)
-
+	ctx, cancel := context.WithCancel(sws.ctx)
 	wch := sws.cw.Watch(ctx,
 		w.wr.key, clientv3.WithRange(w.wr.end),
 		clientv3.WithRev(rev),
 		clientv3.WithProgressNotify(),
 		clientv3.WithCreatedNotify(),
 	)
-
+	sws.mu.Lock()
+	defer sws.mu.Unlock()
 	ws := newWatcherSingle(wch, cancel, w, sws)
 	sws.singles[w.id] = ws
 	go ws.run()
@@ -213,12 +214,11 @@ func (sws *serverWatchStream) maybeCoalesceWatcher(ws watcherSingle) bool {
 	sws.mu.Lock()
 	defer sws.mu.Unlock()
 
-	rid := receiverID{streamID: sws.id, watcherID: ws.w.id}
 	// do not add new watchers when stream is closing
 	if sws.inGroups == nil {
 		return false
 	}
-	if sws.groups.maybeJoinWatcherSingle(rid, ws) {
+	if sws.groups.maybeJoinWatcherSingle(ws) {
 		delete(sws.singles, ws.w.id)
 		sws.inGroups[ws.w.id] = struct{}{}
 		return true
diff --git a/proxy/grpcproxy/watcher.go b/proxy/grpcproxy/watcher.go
index 88c6303af9e..2d25433d972 100644
--- a/proxy/grpcproxy/watcher.go
+++ b/proxy/grpcproxy/watcher.go
@@ -28,13 +28,13 @@ type watchRange struct {
 }
 
 type watcher struct {
-	id int64
-	wr watchRange
+	id  int64
+	wr  watchRange
+	sws *serverWatchStream
 
 	rev      int64
 	filters  []mvcc.FilterFunc
 	progress bool
-	ch       chan<- *pb.WatchResponse
 }
 
 func (w *watcher) send(wr clientv3.WatchResponse) {
@@ -87,10 +87,8 @@ func (w *watcher) send(wr clientv3.WatchResponse) {
 		Events:  events,
 	}
 	select {
-	case w.ch <- pbwr:
+	case w.sws.watchCh <- pbwr:
 	case <-time.After(50 * time.Millisecond):
-		// close the watch chan will notify the stream sender.
-		// the stream will gc all its watchers.
-		close(w.ch)
+		w.sws.cancel()
 	}
 }
diff --git a/proxy/grpcproxy/watcher_group_test.go b/proxy/grpcproxy/watcher_group_test.go
index 188f9a186d3..3436bbb38c1 100644
--- a/proxy/grpcproxy/watcher_group_test.go
+++ b/proxy/grpcproxy/watcher_group_test.go
@@ -17,6 +17,8 @@ package grpcproxy
 import (
 	"testing"
 
+	"golang.org/x/net/context"
+
 	"github.com/coreos/etcd/clientv3"
 	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
 )
@@ -30,8 +32,8 @@ func TestWatchgroupBroadcast(t *testing.T) {
 	for i := range chs {
 		chs[i] = make(chan *pb.WatchResponse, 1)
 		w := watcher{
-			id: int64(i),
-			ch: chs[i],
+			id:  int64(i),
+			sws: &serverWatchStream{watchCh: chs[i], ctx: context.TODO()},
 
 			progress: true,
 		}
diff --git a/proxy/grpcproxy/watcher_groups.go b/proxy/grpcproxy/watcher_groups.go
index 2b92beee9ff..a81e6a2a127 100644
--- a/proxy/grpcproxy/watcher_groups.go
+++ b/proxy/grpcproxy/watcher_groups.go
@@ -61,8 +61,10 @@ func (wgs *watchergroups) addWatcher(rid receiverID, w watcher) {
 			WatchId: rid.watcherID,
 			Created: true,
 		}
-		w.ch <- resp
-
+		select {
+		case w.sws.watchCh <- resp:
+		case <-w.sws.ctx.Done():
+		}
 		return
 	}
 
@@ -96,24 +98,24 @@ func (wgs *watchergroups) removeWatcher(rid receiverID) (int64, bool) {
 	return -1, false
 }
 
-func (wgs *watchergroups) maybeJoinWatcherSingle(rid receiverID, ws watcherSingle) bool {
+func (wgs *watchergroups) maybeJoinWatcherSingle(ws watcherSingle) bool {
 	wgs.mu.Lock()
 	defer wgs.mu.Unlock()
 
+	rid := receiverID{streamID: ws.sws.id, watcherID: ws.w.id}
 	group, ok := wgs.groups[ws.w.wr]
 	if ok {
-		return group.add(receiverID{streamID: ws.sws.id, watcherID: ws.w.id}, ws.w) != -1
+		return group.add(rid, ws.w) != -1
 	}
-
-	if ws.canPromote() {
-		wg := newWatchergroup(ws.ch, ws.cancel)
-		wgs.groups[ws.w.wr] = wg
-		wg.add(receiverID{streamID: ws.sws.id, watcherID: ws.w.id}, ws.w)
-		go wg.run()
-		return true
+	if !ws.canPromote() {
+		return false
 	}
-
-	return false
+	wg := newWatchergroup(ws.ch, ws.cancel)
+	wgs.groups[ws.w.wr] = wg
+	wgs.idToGroup[rid] = wg
+	wg.add(rid, ws.w)
+	go wg.run()
+	return true
 }
 
 func (wgs *watchergroups) stop() {
diff --git a/proxy/grpcproxy/watcher_single.go b/proxy/grpcproxy/watcher_single.go
index d2f5f55a36a..99df1c0f134 100644
--- a/proxy/grpcproxy/watcher_single.go
+++ b/proxy/grpcproxy/watcher_single.go
@@ -54,7 +54,6 @@ func (ws watcherSingle) run() {
 	for wr := range ws.ch {
 		ws.lastStoreRev = wr.Header.Revision
 		ws.w.send(wr)
-
 		if ws.sws.maybeCoalesceWatcher(ws) {
 			return
 		}