diff --git a/src/msg/consumer/consumer.go b/src/msg/consumer/consumer.go index 8703f0b6fa..b27dbc8b1b 100644 --- a/src/msg/consumer/consumer.go +++ b/src/msg/consumer/consumer.go @@ -25,6 +25,8 @@ import ( "sync" "time" + "go.uber.org/zap" + "github.com/m3db/m3/src/msg/generated/proto/msgpb" "github.com/m3db/m3/src/msg/protocol/proto" "github.com/m3db/m3/src/x/clock" @@ -188,9 +190,7 @@ func (c *consumer) tryAck(m msgpb.Metadata) { c.Unlock() return } - if err := c.encodeAckWithLock(ackLen); err != nil { - c.conn.Close() - } + c.trySendAcksWithLock(ackLen) c.Unlock() } @@ -212,30 +212,42 @@ func (c *consumer) ackUntilClose() { func (c *consumer) tryAckAndFlush() { c.Lock() if ackLen := len(c.ackPb.Metadata); ackLen > 0 { - c.encodeAckWithLock(ackLen) + c.trySendAcksWithLock(ackLen) } c.w.Flush() c.Unlock() } -func (c *consumer) encodeAckWithLock(ackLen int) error { +// if acks fail to send the client will retry sending the messages. +func (c *consumer) trySendAcksWithLock(ackLen int) { err := c.encoder.Encode(&c.ackPb) + log := c.opts.InstrumentOptions().Logger() c.ackPb.Metadata = c.ackPb.Metadata[:0] if err != nil { c.m.ackEncodeError.Inc(1) - return err + log.Error("failed to encode ack. client will retry sending message.", zap.Error(err)) + return } _, err = c.w.Write(c.encoder.Bytes()) if err != nil { c.m.ackWriteError.Inc(1) - return err + log.Error("failed to write ack. client will retry sending message.", zap.Error(err)) + c.tryCloseConn() + return } if err := c.w.Flush(); err != nil { c.m.ackWriteError.Inc(1) - return err + log.Error("failed to flush ack. client will retry sending message.", zap.Error(err)) + c.tryCloseConn() + return } c.m.ackSent.Inc(int64(ackLen)) - return nil +} + +func (c *consumer) tryCloseConn() { + if err := c.conn.Close(); err != nil { + c.opts.InstrumentOptions().Logger().Error("failed to close connection.", zap.Error(err)) + } } func (c *consumer) Close() { diff --git a/src/msg/consumer/consumer_test.go b/src/msg/consumer/consumer_test.go index 5889fb3609..c3cc4253fd 100644 --- a/src/msg/consumer/consumer_test.go +++ b/src/msg/consumer/consumer_test.go @@ -141,7 +141,7 @@ func TestConsumerAckReusedMessage(t *testing.T) { require.Equal(t, testMsg2.Metadata, cc.ackPb.Metadata[1]) } -func TestConsumerAckError(t *testing.T) { +func TestConsumerWriteAckError(t *testing.T) { defer leaktest.Check(t)() opts := testOptions() @@ -169,13 +169,62 @@ func TestConsumerAckError(t *testing.T) { require.NoError(t, err) require.Equal(t, testMsg1.Value, m.Bytes()) - mockEncoder.EXPECT().Encode(gomock.Any()).Return(errors.New("mock encode err")) + mockEncoder.EXPECT().Encode(gomock.Any()) + mockEncoder.EXPECT().Bytes().Return([]byte("foo")) + // force a bad write + cc.w.Reset(&badWriter{}) m.Ack() + // connection can no longer be used. _, err = cc.Message() require.Error(t, err) } +type badWriter struct{} + +func (w *badWriter) Write([]byte) (int, error) { + return 0, errors.New("fail") +} + +func TestConsumerDecodeAckError(t *testing.T) { + defer leaktest.Check(t)() + + opts := testOptions() + l, err := NewListener("127.0.0.1:0", opts) + require.NoError(t, err) + defer func() { + require.NoError(t, l.Close()) + }() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + conn, err := net.Dial("tcp", l.Addr().String()) + require.NoError(t, err) + + c, err := l.Accept() + require.NoError(t, err) + + mockEncoder := proto.NewMockEncoder(ctrl) + cc := c.(*consumer) + cc.encoder = mockEncoder + + err = produce(conn, &testMsg1) + require.NoError(t, err) + + m, err := cc.Message() + require.NoError(t, err) + require.Equal(t, testMsg1.Value, m.Bytes()) + + mockEncoder.EXPECT().Encode(gomock.Any()) + mockEncoder.EXPECT().Bytes() + m.Ack() + + // can still use the connection after failing to decode an ack. + err = produce(conn, &testMsg1) + require.NoError(t, err) +} + func TestConsumerMessageError(t *testing.T) { defer leaktest.Check(t)()