diff --git a/etcdmain/grpc_proxy.go b/etcdmain/grpc_proxy.go index e0fc331d2907..b0a59c0f2dab 100644 --- a/etcdmain/grpc_proxy.go +++ b/etcdmain/grpc_proxy.go @@ -206,6 +206,8 @@ func mustNewClient() *clientv3.Client { } cfg.DialOptions = append(cfg.DialOptions, grpc.WithUnaryInterceptor(grpcproxy.AuthUnaryClientInterceptor)) + cfg.DialOptions = append(cfg.DialOptions, + grpc.WithStreamInterceptor(grpcproxy.AuthStreamClientInterceptor)) client, err := clientv3.New(*cfg) if err != nil { fmt.Fprintln(os.Stderr, err) diff --git a/proxy/grpcproxy/util.go b/proxy/grpcproxy/util.go index b8656a8be2fd..f0186ca35dbf 100644 --- a/proxy/grpcproxy/util.go +++ b/proxy/grpcproxy/util.go @@ -53,3 +53,12 @@ func AuthUnaryClientInterceptor(ctx context.Context, method string, req, reply i } return invoker(ctx, method, req, reply, cc, opts...) } + +func AuthStreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + tokenif := ctx.Value("token") + if tokenif != nil { + tokenCred := &proxyTokenCredential{tokenif.(string)} + opts = append(opts, grpc.PerRPCCredentials(tokenCred)) + } + return streamer(ctx, desc, cc, method, opts...) +} diff --git a/proxy/grpcproxy/watch.go b/proxy/grpcproxy/watch.go index b960c94769ac..b18a4e22093a 100644 --- a/proxy/grpcproxy/watch.go +++ b/proxy/grpcproxy/watch.go @@ -40,6 +40,9 @@ type watchProxy struct { // wg waits until all outstanding watch servers quit. wg sync.WaitGroup + + // kv is used for permission checking + kv clientv3.KV } func NewWatchProxy(c *clientv3.Client) (pb.WatchServer, <-chan struct{}) { @@ -48,6 +51,8 @@ func NewWatchProxy(c *clientv3.Client) (pb.WatchServer, <-chan struct{}) { cw: c.Watcher, ctx: cctx, leader: newLeader(c.Ctx(), c.Watcher), + + kv: c.KV, // for permission checking } wp.ranges = newWatchRanges(wp) ch := make(chan struct{}) @@ -92,6 +97,7 @@ func (wp *watchProxy) Watch(stream pb.Watch_WatchServer) (err error) { watchCh: make(chan *pb.WatchResponse, 1024), ctx: ctx, cancel: cancel, + kv: wp.kv, } var lostLeaderC <-chan struct{} @@ -171,6 +177,9 @@ type watchProxyStream struct { ctx context.Context cancel context.CancelFunc + + // kv is used for permission checking + kv clientv3.KV } func (wps *watchProxyStream) close() { @@ -192,6 +201,15 @@ func (wps *watchProxyStream) close() { close(wps.watchCh) } +func (wps *watchProxyStream) checkPermissionForWatch(key, rangeEnd []byte) error { + req := &pb.RangeRequest{ + Key: key, + RangeEnd: rangeEnd, + } + _, err := wps.kv.Do(wps.ctx, RangeRequestToOp(req)) + return err +} + func (wps *watchProxyStream) recvLoop() error { for { req, err := wps.stream.Recv() @@ -201,6 +219,11 @@ func (wps *watchProxyStream) recvLoop() error { switch uv := req.RequestUnion.(type) { case *pb.WatchRequest_CreateRequest: cr := uv.CreateRequest + + if err = wps.checkPermissionForWatch(cr.Key, cr.RangeEnd); err != nil { + return err + } + w := &watcher{ wr: watchRange{string(cr.Key), string(cr.RangeEnd)}, id: wps.nextWatcherID, diff --git a/proxy/grpcproxy/watch_broadcast.go b/proxy/grpcproxy/watch_broadcast.go index 5e750bdb0d40..85cb8ee0a47e 100644 --- a/proxy/grpcproxy/watch_broadcast.go +++ b/proxy/grpcproxy/watch_broadcast.go @@ -59,6 +59,12 @@ func newWatchBroadcast(wp *watchProxy, w *watcher, update func(*watchBroadcast)) clientv3.WithCreatedNotify(), } + // Forward a token from client to server. + token := getAuthTokenFromClient(w.wps.stream.Context()) + if token != "" { + cctx = context.WithValue(cctx, "token", token) + } + wch := wp.cw.Watch(cctx, w.wr.key, opts...) for wr := range wch {