diff --git a/pkg/store/proxy.go b/pkg/store/proxy.go index 55deb89910..c425110777 100644 --- a/pkg/store/proxy.go +++ b/pkg/store/proxy.go @@ -348,14 +348,14 @@ type recvResponse struct { err error } -func startFrameCtx(responseTimeout time.Duration) (context.Context, context.CancelFunc) { +func frameCtx(responseTimeout time.Duration) (context.Context, context.CancelFunc) { frameTimeoutCtx := context.Background() var cancel context.CancelFunc if responseTimeout != 0 { frameTimeoutCtx, cancel = context.WithTimeout(frameTimeoutCtx, responseTimeout) return frameTimeoutCtx, cancel } - return frameTimeoutCtx, nil + return frameTimeoutCtx, func() {} } func startStreamSeriesSet( @@ -393,43 +393,49 @@ func startStreamSeriesSet( emptyStreamResponses.Inc() } }() - for { - frameTimeoutCtx, cancel := startFrameCtx(s.responseTimeout) - if cancel != nil { - defer cancel() + + rCh := make(chan *recvResponse) + recvCancel := make(chan bool) + go func() { + for { + r, err := s.stream.Recv() + select { + case <-recvCancel: + close(rCh) + return + case rCh <- &recvResponse{r: r, err: err}: + } } - rCh := make(chan *recvResponse, 1) + }() + for { + frameTimeoutCtx, cancel := frameCtx(s.responseTimeout) + defer cancel() var rr *recvResponse - go func() { - r, err := s.stream.Recv() - rCh <- &recvResponse{r: r, err: err} - close(rCh) - }() + var err error select { case <-ctx.Done(): - s.timeoutHandling(true, ctx) + close(recvCancel) + err = errors.Wrap(ctx.Err(), fmt.Sprintf("failed to receive any data from %s", s.name)) + s.handleErr(err) return case <-frameTimeoutCtx.Done(): - s.timeoutHandling(false, frameTimeoutCtx) + close(recvCancel) + err = errors.Wrap(frameTimeoutCtx.Err(), fmt.Sprintf("failed to receive any data in %s from %s", s.responseTimeout.String(), s.name)) + s.handleErr(err) return case rr = <-rCh: } if rr.err == io.EOF { + close(recvCancel) return } if rr.err != nil { wrapErr := errors.Wrapf(rr.err, "receive series from %s", s.name) - if partialResponse { - s.warnCh.send(storepb.NewWarnSeriesResponse(wrapErr)) - return - } - - s.errMtx.Lock() - s.err = wrapErr - s.errMtx.Unlock() + s.handleErr(wrapErr) + close(recvCancel) return } numResponses++ @@ -444,13 +450,7 @@ func startStreamSeriesSet( return s } -func (s *streamSeriesSet) timeoutHandling(isQueryTimeout bool, ctx context.Context) { - var err error - if isQueryTimeout { - err = errors.Wrap(ctx.Err(), fmt.Sprintf("failed to receive any data from %s", s.name)) - } else { - err = errors.Wrap(ctx.Err(), fmt.Sprintf("failed to receive any data in %s from %s", s.responseTimeout.String(), s.name)) - } +func (s *streamSeriesSet) handleErr(err error) { s.closeSeries() if s.partialResponse { level.Warn(s.logger).Log("err", err, "msg", "returning partial response")