diff --git a/cmd/ingester/app/consumer/consumer.go b/cmd/ingester/app/consumer/consumer.go
index 595620361af..856336239b4 100644
--- a/cmd/ingester/app/consumer/consumer.go
+++ b/cmd/ingester/app/consumer/consumer.go
@@ -50,10 +50,11 @@ type Consumer struct {
 	partitionMapLock    sync.Mutex
 	partitionsHeld      int64
 	partitionsHeldGauge metrics.Gauge
+
+	doneWg sync.WaitGroup
 }
 
 type consumerState struct {
-	wg                sync.WaitGroup
 	partitionConsumer sc.PartitionConsumer
 }
 
@@ -78,17 +79,11 @@ func (c *Consumer) Start() {
 		c.logger.Info("Starting main loop")
 		for pc := range c.internalConsumer.Partitions() {
 			c.partitionMapLock.Lock()
-			if p, ok := c.partitionIDToState[pc.Partition()]; ok {
-				// This is a guard against simultaneously draining messages
-				// from the last time the partition was assigned and
-				// processing new messages for the same partition, which may lead
-				// to the cleanup process not completing
-				p.wg.Wait()
-			}
 			c.partitionIDToState[pc.Partition()] = &consumerState{partitionConsumer: pc}
-			c.partitionIDToState[pc.Partition()].wg.Add(2)
 			c.partitionMapLock.Unlock()
 			c.partitionMetrics(pc.Partition()).startCounter.Inc(1)
+
+			c.doneWg.Add(2)
 			go c.handleMessages(pc)
 			go c.handleErrors(pc.Partition(), pc.Errors())
 		}
@@ -97,31 +92,33 @@ func (c *Consumer) Start() {
 
 // Close closes the Consumer and underlying sarama consumer
 func (c *Consumer) Close() error {
-	c.partitionMapLock.Lock()
-	for _, p := range c.partitionIDToState {
-		c.closePartition(p.partitionConsumer)
-		p.wg.Wait()
-	}
-	c.partitionMapLock.Unlock()
-	c.deadlockDetector.close()
+	// Close the internal consumer, which will close each partition consumer's message and error channels.
 	c.logger.Info("Closing parent consumer")
-	return c.internalConsumer.Close()
+	err := c.internalConsumer.Close()
+
+	c.logger.Debug("Closing deadlock detector")
+	c.deadlockDetector.close()
+
+	c.logger.Debug("Waiting for messages and errors to be handled")
+	c.doneWg.Wait()
+
+	return err
 }
 
+// handleMessages handles incoming Kafka messages on a channel
 func (c *Consumer) handleMessages(pc sc.PartitionConsumer) {
 	c.logger.Info("Starting message handler", zap.Int32("partition", pc.Partition()))
 	c.partitionMapLock.Lock()
 	c.partitionsHeld++
 	c.partitionsHeldGauge.Update(c.partitionsHeld)
-	wg := &c.partitionIDToState[pc.Partition()].wg
 	c.partitionMapLock.Unlock()
 	defer func() {
 		c.closePartition(pc)
-		wg.Done()
 		c.partitionMapLock.Lock()
 		c.partitionsHeld--
 		c.partitionsHeldGauge.Update(c.partitionsHeld)
 		c.partitionMapLock.Unlock()
+		c.doneWg.Done()
 	}()
 
 	msgMetrics := c.newMsgMetrics(pc.Partition())
@@ -165,12 +162,10 @@ func (c *Consumer) closePartition(partitionConsumer sc.PartitionConsumer) {
 	c.logger.Info("Closed partition consumer", zap.Int32("partition", partitionConsumer.Partition()))
 }
 
+// handleErrors handles incoming Kafka consumer errors on a channel
 func (c *Consumer) handleErrors(partition int32, errChan <-chan *sarama.ConsumerError) {
 	c.logger.Info("Starting error handler", zap.Int32("partition", partition))
-	c.partitionMapLock.Lock()
-	wg := &c.partitionIDToState[partition].wg
-	c.partitionMapLock.Unlock()
-	defer wg.Done()
+	defer c.doneWg.Done()
 
 	errMetrics := c.newErrMetrics(partition)
 	for err := range errChan {
diff --git a/cmd/ingester/app/consumer/consumer_test.go b/cmd/ingester/app/consumer/consumer_test.go
index 21a4dcd87ac..a903f5e6e90 100644
--- a/cmd/ingester/app/consumer/consumer_test.go
+++ b/cmd/ingester/app/consumer/consumer_test.go
@@ -68,7 +68,7 @@ func (s partitionConsumerWrapper) Topic() string {
 	return s.topic
 }
 
-func newSaramaClusterConsumer(saramaPartitionConsumer sarama.PartitionConsumer) *kmocks.Consumer {
+func newSaramaClusterConsumer(saramaPartitionConsumer sarama.PartitionConsumer, mc *smocks.PartitionConsumer) *kmocks.Consumer {
 	pcha := make(chan cluster.PartitionConsumer, 1)
 	pcha <- &partitionConsumerWrapper{
 		topic:     topic,
@@ -77,27 +77,26 @@ func newSaramaClusterConsumer(saramaPartitionConsumer sarama.PartitionConsumer)
 	}
 	saramaClusterConsumer := &kmocks.Consumer{}
 	saramaClusterConsumer.On("Partitions").Return((<-chan cluster.PartitionConsumer)(pcha))
-	saramaClusterConsumer.On("Close").Return(nil)
+	saramaClusterConsumer.On("Close").Return(nil).Run(func(args mock.Arguments) {
+		mc.Close()
+	})
 	saramaClusterConsumer.On("MarkPartitionOffset", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
 	return saramaClusterConsumer
 }
 
 func newConsumer(
+	t *testing.T,
 	metricsFactory metrics.Factory,
 	topic string,
 	processor processor.SpanProcessor,
 	consumer consumer.Consumer) *Consumer {
 
 	logger, _ := zap.NewDevelopment()
-	return &Consumer{
-		metricsFactory:      metricsFactory,
-		logger:              logger,
-		internalConsumer:    consumer,
-		partitionIDToState:  make(map[int32]*consumerState),
-		partitionsHeldGauge: partitionsHeldGauge(metricsFactory),
-		deadlockDetector:    newDeadlockDetector(metricsFactory, logger, time.Second),
-
-		processorFactory: ProcessorFactory{
+	consumerParams := Params{
+		MetricsFactory:   metricsFactory,
+		Logger:           logger,
+		InternalConsumer: consumer,
+		ProcessorFactory: ProcessorFactory{
 			topic:          topic,
 			consumer:       consumer,
 			metricsFactory: metricsFactory,
@@ -106,6 +105,10 @@ func newConsumer(
 			parallelism:    1,
 		},
 	}
+
+	c, err := New(consumerParams)
+	require.NoError(t, err)
+	return c
 }
 
 func TestSaramaConsumerWrapper_MarkPartitionOffset(t *testing.T) {
@@ -136,7 +139,7 @@ func TestSaramaConsumerWrapper_start_Messages(t *testing.T) {
 	saramaPartitionConsumer, e := saramaConsumer.ConsumePartition(topic, partition, msgOffset)
 	require.NoError(t, e)
 
-	undertest := newConsumer(localFactory, topic, mp, newSaramaClusterConsumer(saramaPartitionConsumer))
+	undertest := newConsumer(t, localFactory, topic, mp, newSaramaClusterConsumer(saramaPartitionConsumer, mc))
 
 	undertest.partitionIDToState = map[int32]*consumerState{
 		partition: {
@@ -202,7 +205,7 @@ func TestSaramaConsumerWrapper_start_Errors(t *testing.T) {
 	saramaPartitionConsumer, e := saramaConsumer.ConsumePartition(topic, partition, msgOffset)
 	require.NoError(t, e)
 
-	undertest := newConsumer(localFactory, topic, &pmocks.SpanProcessor{}, newSaramaClusterConsumer(saramaPartitionConsumer))
+	undertest := newConsumer(t, localFactory, topic, &pmocks.SpanProcessor{}, newSaramaClusterConsumer(saramaPartitionConsumer, mc))
 	undertest.Start()
 
 	mc.YieldError(errors.New("Daisy, Daisy"))
@@ -238,7 +241,7 @@ func TestHandleClosePartition(t *testing.T) {
 	saramaPartitionConsumer, e := saramaConsumer.ConsumePartition(topic, partition, msgOffset)
 	require.NoError(t, e)
 
-	undertest := newConsumer(metricsFactory, topic, mp, newSaramaClusterConsumer(saramaPartitionConsumer))
+	undertest := newConsumer(t, metricsFactory, topic, mp, newSaramaClusterConsumer(saramaPartitionConsumer, mc))
 	undertest.deadlockDetector = newDeadlockDetector(metricsFactory, undertest.logger, 200*time.Millisecond)
 	undertest.Start()
 	defer undertest.Close()
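
The core of the change is replacing the per-partition wg stored in consumerState with a single consumer-level doneWg, and reordering Close() so the internal consumer is closed before waiting. Below is a minimal standalone sketch of that pattern, not Jaeger's actual code: innerConsumer, consumer, and the channel fields are invented for illustration. It shows why the ordering works: closing the channel owner ends the handlers' range loops, so doneWg.Wait() can return.

// Sketch (assumed names, not the Jaeger types) of the shutdown pattern the
// diff adopts: handlers register on one WaitGroup; Close() closes the channel
// owner first, then waits for the handlers to drain.
package main

import (
	"fmt"
	"sync"
)

// innerConsumer stands in for the wrapped sarama-cluster consumer; closing it
// closes the message and error channels the handlers range over.
type innerConsumer struct {
	messages chan string
	errs     chan error
}

func (ic *innerConsumer) Close() error {
	close(ic.messages)
	close(ic.errs)
	return nil
}

type consumer struct {
	inner  *innerConsumer
	doneWg sync.WaitGroup // counts running handler goroutines, like doneWg in the diff
}

func (c *consumer) Start() {
	// Two handlers per partition: one for messages, one for errors.
	c.doneWg.Add(2)
	go c.handleMessages()
	go c.handleErrors()
}

func (c *consumer) handleMessages() {
	defer c.doneWg.Done()
	for msg := range c.inner.messages { // loop ends once Close() closes the channel
		fmt.Println("processed:", msg)
	}
}

func (c *consumer) handleErrors() {
	defer c.doneWg.Done()
	for err := range c.inner.errs {
		fmt.Println("consumer error:", err)
	}
}

// Close mirrors the new ordering: close the underlying consumer (which closes
// the per-partition channels), then wait for the handlers to finish draining.
func (c *consumer) Close() error {
	err := c.inner.Close()
	c.doneWg.Wait()
	return err
}

func main() {
	c := &consumer{inner: &innerConsumer{
		messages: make(chan string, 1),
		errs:     make(chan error, 1),
	}}
	c.Start()
	c.inner.messages <- "span"
	_ = c.Close() // returns only after both handlers have exited
}

Because the handlers block on channel range loops, waiting on doneWg before closing the underlying consumer would deadlock; closing first and waiting second, as the diff does, avoids that. This is also why the test's mocked "Close" now triggers mc.Close(), so the mock partition consumer's channels actually close and the handlers can finish.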