diff --git a/src/aggregator/server/m3msg/server.go b/src/aggregator/server/m3msg/server.go
index 924254d313..bf98e4850e 100644
--- a/src/aggregator/server/m3msg/server.go
+++ b/src/aggregator/server/m3msg/server.go
@@ -21,9 +21,7 @@ package m3msg
 
 import (
-	"errors"
 	"fmt"
-	"io"
 
 	"github.com/m3db/m3/src/aggregator/aggregator"
 	"github.com/m3db/m3/src/metrics/encoding"
@@ -35,11 +33,6 @@ import (
 	"go.uber.org/zap"
 )
 
-type server struct {
-	aggregator aggregator.Aggregator
-	logger     *zap.Logger
-}
-
 // NewServer creates a new M3Msg server.
 func NewServer(
 	address string,
@@ -49,49 +42,43 @@ func NewServer(
 	if err := opts.Validate(); err != nil {
 		return nil, err
 	}
-
-	s := &server{
-		aggregator: aggregator,
-		logger:     opts.InstrumentOptions().Logger(),
+	newMessageProcessor := func() consumer.MessageProcessor {
+		// construct a new messageProcessor per consumer so the internal protos can be reused across messages on the
+		// same connection.
+		return &messageProcessor{
+			aggregator: aggregator,
+			logger:     opts.InstrumentOptions().Logger(),
+		}
 	}
-
-	handler := consumer.NewConsumerHandler(s.Consume, opts.ConsumerOptions())
+	handler := consumer.NewMessageHandler(newMessageProcessor, opts.ConsumerOptions())
 	return xserver.NewServer(address, handler, opts.ServerOptions()), nil
 }
 
-func (s *server) Consume(c consumer.Consumer) {
-	var (
-		pb    = &metricpb.MetricWithMetadatas{}
-		union = &encoding.UnaggregatedMessageUnion{}
-	)
-	for {
-		msg, err := c.Message()
-		if err != nil {
-			if !errors.Is(err, io.EOF) {
-				s.logger.Error("could not read message", zap.Error(err))
-			}
-			break
-		}
+type messageProcessor struct {
+	pb         metricpb.MetricWithMetadatas
+	union      encoding.UnaggregatedMessageUnion
+	aggregator aggregator.Aggregator
+	logger     *zap.Logger
+}
 
-		// Reset and reuse the protobuf message for unpacking.
-		protobuf.ReuseMetricWithMetadatasProto(pb)
-		if err = s.handleMessage(pb, union, msg); err != nil {
-			s.logger.Error("could not process message",
-				zap.Error(err),
-				zap.Uint64("shard", msg.ShardID()),
-				zap.String("proto", pb.String()))
-		}
+func (m *messageProcessor) Process(msg consumer.Message) {
+	if err := m.handleMessage(&m.pb, &m.union, msg); err != nil {
+		m.logger.Error("could not process message",
+			zap.Error(err),
+			zap.Uint64("shard", msg.ShardID()),
+			zap.String("proto", m.pb.String()))
 	}
-	c.Close()
 }
 
-func (s *server) handleMessage(
+func (m *messageProcessor) handleMessage(
 	pb *metricpb.MetricWithMetadatas,
 	union *encoding.UnaggregatedMessageUnion,
 	msg consumer.Message,
 ) error {
 	defer msg.Ack()
+	// Reset and reuse the protobuf message for unpacking.
+	protobuf.ReuseMetricWithMetadatasProto(&m.pb)
 
 	// Unmarshal the message.
 	if err := pb.Unmarshal(msg.Bytes()); err != nil {
 		return err
@@ -104,27 +91,27 @@ func (s *server) handleMessage(
 			return err
 		}
 		u := union.CounterWithMetadatas.ToUnion()
-		return s.aggregator.AddUntimed(u, union.CounterWithMetadatas.StagedMetadatas)
+		return m.aggregator.AddUntimed(u, union.CounterWithMetadatas.StagedMetadatas)
 	case metricpb.MetricWithMetadatas_BATCH_TIMER_WITH_METADATAS:
 		err := union.BatchTimerWithMetadatas.FromProto(pb.BatchTimerWithMetadatas)
 		if err != nil {
 			return err
 		}
 		u := union.BatchTimerWithMetadatas.ToUnion()
-		return s.aggregator.AddUntimed(u, union.BatchTimerWithMetadatas.StagedMetadatas)
+		return m.aggregator.AddUntimed(u, union.BatchTimerWithMetadatas.StagedMetadatas)
 	case metricpb.MetricWithMetadatas_GAUGE_WITH_METADATAS:
 		err := union.GaugeWithMetadatas.FromProto(pb.GaugeWithMetadatas)
 		if err != nil {
 			return err
 		}
 		u := union.GaugeWithMetadatas.ToUnion()
-		return s.aggregator.AddUntimed(u, union.GaugeWithMetadatas.StagedMetadatas)
+		return m.aggregator.AddUntimed(u, union.GaugeWithMetadatas.StagedMetadatas)
 	case metricpb.MetricWithMetadatas_FORWARDED_METRIC_WITH_METADATA:
 		err := union.ForwardedMetricWithMetadata.FromProto(pb.ForwardedMetricWithMetadata)
 		if err != nil {
 			return err
 		}
-		return s.aggregator.AddForwarded(
+		return m.aggregator.AddForwarded(
 			union.ForwardedMetricWithMetadata.ForwardedMetric,
 			union.ForwardedMetricWithMetadata.ForwardMetadata)
 	case metricpb.MetricWithMetadatas_TIMED_METRIC_WITH_METADATA:
@@ -132,7 +119,7 @@ func (s *server) handleMessage(
 		if err != nil {
 			return err
 		}
-		return s.aggregator.AddTimed(
+		return m.aggregator.AddTimed(
 			union.TimedMetricWithMetadata.Metric,
 			union.TimedMetricWithMetadata.TimedMetadata)
 	case metricpb.MetricWithMetadatas_TIMED_METRIC_WITH_METADATAS:
@@ -140,10 +127,12 @@ func (s *server) handleMessage(
 		if err != nil {
 			return err
 		}
-		return s.aggregator.AddTimedWithStagedMetadatas(
+		return m.aggregator.AddTimedWithStagedMetadatas(
 			union.TimedMetricWithMetadatas.Metric,
 			union.TimedMetricWithMetadatas.StagedMetadatas)
 	default:
 		return fmt.Errorf("unrecognized message type: %v", pb.Type)
 	}
 }
+
+func (m *messageProcessor) Close() {}
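The comment in NewServer above captures the key invariant: NewMessageHandler invokes the factory once per accepted connection, and messages on a connection are processed serially, so the processor's embedded protos need no synchronization. A minimal sketch of the same pattern against the consumer package, not part of this diff (the reusingProcessor type and its logging are illustrative):

package example

import (
	"log"

	"github.com/m3db/m3/src/metrics/encoding/protobuf"
	"github.com/m3db/m3/src/metrics/generated/proto/metricpb"
	"github.com/m3db/m3/src/msg/consumer"
)

// reusingProcessor mirrors the diff's messageProcessor: one instance per
// connection, so the pb field is reused across messages without locking.
type reusingProcessor struct {
	pb metricpb.MetricWithMetadatas
}

func (p *reusingProcessor) Process(msg consumer.Message) {
	defer msg.Ack()
	// Reset the proto while keeping nested allocations, as the diff does.
	protobuf.ReuseMetricWithMetadatasProto(&p.pb)
	if err := p.pb.Unmarshal(msg.Bytes()); err != nil {
		log.Printf("decode error: %v", err)
	}
}

func (p *reusingProcessor) Close() {}

// Wiring, as in NewServer above:
// handler := consumer.NewMessageHandler(func() consumer.MessageProcessor {
//     return &reusingProcessor{}
// }, opts)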
diff --git a/src/cmd/services/m3coordinator/server/m3msg/config.go b/src/cmd/services/m3coordinator/server/m3msg/config.go
index 953f62a2d0..a37119cc97 100644
--- a/src/cmd/services/m3coordinator/server/m3msg/config.go
+++ b/src/cmd/services/m3coordinator/server/m3msg/config.go
@@ -86,7 +86,7 @@ func (c handlerConfiguration) newHandler(
 		ProtobufDecoderPoolOptions: c.ProtobufDecoderPool.NewObjectPoolOptions(iOpts),
 		BlockholePolicies:          c.BlackholePolicies,
 	})
-	return consumer.NewMessageHandler(p, cOpts), nil
+	return consumer.NewMessageHandler(consumer.SingletonMessageProcessor(p), cOpts), nil
 }
 
 // NewOptions creates handler options.
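SingletonMessageProcessor adapts the coordinator's existing single processor to the new factory-based API: every connection shares one instance. Two contracts follow from the consumer changes below: Process may be called concurrently from different connections, and Close is now invoked once per connection close rather than once per handler, so it should be idempotent. A hedged sketch (sharedProcessor and its counter are hypothetical):

package example

import (
	"sync"

	"github.com/m3db/m3/src/msg/consumer"
	"github.com/m3db/m3/src/x/server"
)

// sharedProcessor is shared by all connections via SingletonMessageProcessor,
// so mutable state needs a lock and Close must tolerate repeated calls.
type sharedProcessor struct {
	mu        sync.Mutex
	processed int
	closeOnce sync.Once
}

func (p *sharedProcessor) Process(msg consumer.Message) {
	defer msg.Ack()
	p.mu.Lock() // Process can run concurrently for different connections.
	p.processed++
	p.mu.Unlock()
}

func (p *sharedProcessor) Close() {
	p.closeOnce.Do(func() {
		// Release shared resources exactly once.
	})
}

func newSharedHandler(opts consumer.Options) server.Handler {
	return consumer.NewMessageHandler(consumer.SingletonMessageProcessor(&sharedProcessor{}), opts)
}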
diff --git a/src/cmd/services/m3coordinator/server/m3msg/protobuf_handler_test.go b/src/cmd/services/m3coordinator/server/m3msg/protobuf_handler_test.go
index f11de18ff5..b15cf8af9b 100644
--- a/src/cmd/services/m3coordinator/server/m3msg/protobuf_handler_test.go
+++ b/src/cmd/services/m3coordinator/server/m3msg/protobuf_handler_test.go
@@ -67,7 +67,7 @@ func TestM3MsgServerWithProtobufHandler(t *testing.T) {
 
 	s := server.NewServer(
 		"a",
-		consumer.NewMessageHandler(newProtobufProcessor(hOpts), opts),
+		consumer.NewMessageHandler(consumer.SingletonMessageProcessor(newProtobufProcessor(hOpts)), opts),
 		server.NewOptions(),
 	)
 	s.Serve(l)
@@ -150,7 +150,7 @@ func TestM3MsgServerWithProtobufHandler_Blackhole(t *testing.T) {
 
 	s := server.NewServer(
 		"a",
-		consumer.NewMessageHandler(newProtobufProcessor(hOpts), opts),
+		consumer.NewMessageHandler(consumer.SingletonMessageProcessor(newProtobufProcessor(hOpts)), opts),
 		server.NewOptions(),
 	)
 	s.Serve(l)
diff --git a/src/msg/consumer/consumer.go b/src/msg/consumer/consumer.go
index 50a398488b..f880b67ec7 100644
--- a/src/msg/consumer/consumer.go
+++ b/src/msg/consumer/consumer.go
@@ -66,7 +66,7 @@ func (l *listener) Accept() (Consumer, error) {
 		return nil, err
 	}
 
-	return newConsumer(conn, l.msgPool, l.opts, l.m), nil
+	return newConsumer(conn, l.msgPool, l.opts, l.m, NewNoOpMessageProcessor), nil
 }
 
 type metrics struct {
@@ -97,11 +97,12 @@ type consumer struct {
 	w    xio.ResettableWriter
 	conn net.Conn
 
-	ackPb  msgpb.Ack
-	closed bool
-	doneCh chan struct{}
-	wg     sync.WaitGroup
-	m      metrics
+	ackPb            msgpb.Ack
+	closed           bool
+	doneCh           chan struct{}
+	wg               sync.WaitGroup
+	m                metrics
+	messageProcessor MessageProcessor
 }
 
 func newConsumer(
@@ -109,6 +110,7 @@ func newConsumer(
 	mPool *messagePool,
 	opts Options,
 	m metrics,
+	newMessageProcessorFn NewMessageProcessorFn,
 ) *consumer {
 	var (
 		wOpts = xio.ResettableWriterOptions{
@@ -126,11 +128,12 @@ func newConsumer(
 		decoder: proto.NewDecoder(
 			conn, opts.DecoderOptions(), opts.ConnectionReadBufferSize(),
 		),
-		w:      writerFn(newConnWithTimeout(conn, opts.ConnectionWriteTimeout(), time.Now), wOpts),
-		conn:   conn,
-		closed: false,
-		doneCh: make(chan struct{}),
-		m:      m,
+		w:                writerFn(newConnWithTimeout(conn, opts.ConnectionWriteTimeout(), time.Now), wOpts),
+		conn:             conn,
+		closed:           false,
+		doneCh:           make(chan struct{}),
+		m:                m,
+		messageProcessor: newMessageProcessorFn(),
 	}
 }
 
@@ -141,6 +144,9 @@ func (c *consumer) Init() {
 		c.wg.Done()
 	}()
 }
+func (c *consumer) process(m Message) {
+	c.messageProcessor.Process(m)
+}
 
 func (c *consumer) Message() (Message, error) {
 	m := c.mPool.Get()
@@ -230,6 +236,7 @@ func (c *consumer) Close() {
 	close(c.doneCh)
 	c.wg.Wait()
 	c.conn.Close()
+	c.messageProcessor.Close()
 }
 
 type message struct {
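With newConsumer now taking a NewMessageProcessorFn, the factory runs when the connection is accepted, and the processor's Close runs at the end of consumer.Close, which gives a natural place to hold and release connection-scoped state. An illustrative sketch (connScopedProcessor and its scratch buffer are hypothetical):

package example

import (
	"bytes"

	"github.com/m3db/m3/src/msg/consumer"
)

// connScopedProcessor owns per-connection scratch space created on accept.
type connScopedProcessor struct {
	buf bytes.Buffer
}

func newConnScopedProcessor() consumer.MessageProcessor {
	return &connScopedProcessor{}
}

func (p *connScopedProcessor) Process(msg consumer.Message) {
	defer msg.Ack()
	p.buf.Reset()
	p.buf.Write(msg.Bytes()) // buffer reused across this connection's messages
}

// Close drops connection-scoped state; the consumer calls it when it closes.
func (p *connScopedProcessor) Close() {
	p.buf.Reset()
}

// The factory satisfies the new NewMessageProcessorFn type.
var _ consumer.NewMessageProcessorFn = newConnScopedProcessor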
diff --git a/src/msg/consumer/handlers.go b/src/msg/consumer/handlers.go
index 353799e259..572767a078 100644
--- a/src/msg/consumer/handlers.go
+++ b/src/msg/consumer/handlers.go
@@ -29,54 +29,27 @@ import (
 	"go.uber.org/zap"
 )
 
-type consumerHandler struct {
-	opts      Options
-	mPool     *messagePool
-	consumeFn ConsumeFn
-	m         metrics
-}
-
-// NewConsumerHandler creates a new server handler with consumerFn.
-func NewConsumerHandler(consumeFn ConsumeFn, opts Options) server.Handler {
-	mPool := newMessagePool(opts.MessagePoolOptions())
-	mPool.Init()
-	return &consumerHandler{
-		consumeFn: consumeFn,
-		opts:      opts,
-		mPool:     mPool,
-		m:         newConsumerMetrics(opts.InstrumentOptions().MetricsScope()),
-	}
-}
-
-func (h *consumerHandler) Handle(conn net.Conn) {
-	c := newConsumer(conn, h.mPool, h.opts, h.m)
-	c.Init()
-	h.consumeFn(c)
-}
-
-func (h *consumerHandler) Close() {}
-
 type messageHandler struct {
-	opts  Options
-	mPool *messagePool
-	mp    MessageProcessor
-	m     metrics
+	opts                  Options
+	mPool                 *messagePool
+	newMessageProcessorFn NewMessageProcessorFn
+	m                     metrics
 }
 
-// NewMessageHandler creates a new server handler with messageFn.
-func NewMessageHandler(mp MessageProcessor, opts Options) server.Handler {
+// NewMessageHandler creates a new server handler that runs a MessageProcessor per connection.
+func NewMessageHandler(newMessageProcessorFn NewMessageProcessorFn, opts Options) server.Handler {
 	mPool := newMessagePool(opts.MessagePoolOptions())
 	mPool.Init()
 	return &messageHandler{
-		mp:    mp,
-		opts:  opts,
-		mPool: mPool,
-		m:     newConsumerMetrics(opts.InstrumentOptions().MetricsScope()),
+		newMessageProcessorFn: newMessageProcessorFn,
+		opts:                  opts,
+		mPool:                 mPool,
+		m:                     newConsumerMetrics(opts.InstrumentOptions().MetricsScope()),
 	}
 }
 
 func (h *messageHandler) Handle(conn net.Conn) {
-	c := newConsumer(conn, h.mPool, h.opts, h.m)
+	c := newConsumer(conn, h.mPool, h.opts, h.m, h.newMessageProcessorFn)
 	c.Init()
 	var (
 		msgErr error
@@ -87,7 +60,7 @@ func (h *messageHandler) Handle(conn net.Conn) {
 		if msgErr != nil {
 			break
 		}
-		h.mp.Process(msg)
+		c.process(msg)
 	}
 	if msgErr != nil && msgErr != io.EOF {
 		h.opts.InstrumentOptions().Logger().With(zap.Error(msgErr)).Error("could not read message from consumer")
@@ -95,4 +68,4 @@ func (h *messageHandler) Handle(conn net.Conn) {
 	c.Close()
 }
 
-func (h *messageHandler) Close() { h.mp.Close() }
+func (h *messageHandler) Close() {}
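Handle reads and processes messages in a single loop per connection, so Process calls on a per-connection processor never overlap. A small sketch of what that guarantee permits (the orderedProcessor type is hypothetical):

package example

import "github.com/m3db/m3/src/msg/consumer"

// orderedProcessor keeps plain fields with no mutex: the handler's read loop
// runs Process serially for a given connection.
type orderedProcessor struct {
	count     int
	lastShard uint64
}

func (p *orderedProcessor) Process(msg consumer.Message) {
	defer msg.Ack()
	p.count++
	p.lastShard = msg.ShardID()
}

func (p *orderedProcessor) Close() {}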
diff --git a/src/msg/consumer/handlers_test.go b/src/msg/consumer/handlers_test.go
index 63ae802c3c..1c7f9d99c5 100644
--- a/src/msg/consumer/handlers_test.go
+++ b/src/msg/consumer/handlers_test.go
@@ -28,6 +28,7 @@ import (
 	"github.com/m3db/m3/src/msg/generated/proto/msgpb"
 	"github.com/m3db/m3/src/msg/protocol/proto"
 	"github.com/m3db/m3/src/x/server"
+	xtest "github.com/m3db/m3/src/x/test"
 
 	"github.com/fortytw2/leaktest"
 	"github.com/golang/mock/gomock"
@@ -59,7 +60,7 @@ func TestServerWithMessageFn(t *testing.T) {
 	l, err := net.Listen("tcp", "127.0.0.1:0")
 	require.NoError(t, err)
 
-	s := server.NewServer("a", NewMessageHandler(p, opts), server.NewOptions())
+	s := server.NewServer("a", NewMessageHandler(SingletonMessageProcessor(p), opts), server.NewOptions())
 	s.Serve(l)
 
 	conn, err := net.Dial("tcp", l.Addr().String())
@@ -88,57 +89,57 @@ func TestServerWithMessageFn(t *testing.T) {
 	s.Close()
 }
 
-func TestServerWithConsumeFn(t *testing.T) {
+func TestServerMessageDifferentConnections(t *testing.T) {
 	defer leaktest.Check(t)()
-	var (
-		count  = 0
-		bytes  []byte
-		closed bool
-		wg     sync.WaitGroup
-	)
-	consumeFn := func(c Consumer) {
-		for {
-			count++
-			m, err := c.Message()
-			if err != nil {
-				break
-			}
-			bytes = m.Bytes()
-			m.Ack()
-			wg.Done()
-		}
-		c.Close()
-		closed = true
-	}
+	ctrl := xtest.NewController(t)
+	defer ctrl.Finish()
 
 	l, err := net.Listen("tcp", "127.0.0.1:0")
 	require.NoError(t, err)
 
+	var wg sync.WaitGroup
+	wg.Add(2)
+	handleMessage := func(m Message) {
+		wg.Done()
+	}
+
+	mp1 := NewMockMessageProcessor(ctrl)
+	mp2 := NewMockMessageProcessor(ctrl)
+	mp1.EXPECT().Process(gomock.Any()).Do(handleMessage)
+	mp1.EXPECT().Close()
+	mp2.EXPECT().Process(gomock.Any()).Do(handleMessage)
+	mp2.EXPECT().Close()
+
 	// Set a large ack buffer size to make sure the background go routine
 	// can flush it.
 	opts := testOptions().SetAckBufferSize(100)
-	s := server.NewServer("a", NewConsumerHandler(consumeFn, opts), server.NewOptions())
-	require.NoError(t, err)
-	s.Serve(l)
+	first := true
+	var mu sync.Mutex
+	newMessageProcessor := func() MessageProcessor {
+		mu.Lock()
+		defer mu.Unlock()
+		if first {
+			first = false
+			return mp1
+		}
+		return mp2
+	}
 
-	conn, err := net.Dial("tcp", l.Addr().String())
+	s := server.NewServer("a", NewMessageHandler(newMessageProcessor, opts), server.NewOptions())
 	require.NoError(t, err)
+	require.NoError(t, s.Serve(l))
 
-	wg.Add(1)
-	err = produce(conn, &testMsg1)
+	conn1, err := net.Dial("tcp", l.Addr().String())
+	require.NoError(t, err)
+	conn2, err := net.Dial("tcp", l.Addr().String())
 	require.NoError(t, err)
 
-	wg.Wait()
-	require.Equal(t, testMsg1.Value, bytes)
-
-	var ack msgpb.Ack
-	testDecoder := proto.NewDecoder(conn, opts.DecoderOptions(), 10)
-	err = testDecoder.Decode(&ack)
+	err = produce(conn1, &testMsg1)
+	require.NoError(t, err)
+	err = produce(conn2, &testMsg1)
 	require.NoError(t, err)
-	require.Equal(t, 1, len(ack.Metadata))
-	require.Equal(t, testMsg1.Metadata, ack.Metadata[0])
+	wg.Wait()
 
 	s.Close()
-	require.True(t, closed)
 }
diff --git a/src/msg/consumer/types.go b/src/msg/consumer/types.go
index 79b7c243b1..bdd1ea3f06 100644
--- a/src/msg/consumer/types.go
+++ b/src/msg/consumer/types.go
@@ -129,6 +129,28 @@ type MessageProcessor interface {
 	Close()
 }
 
+// NewMessageProcessorFn creates a new MessageProcessor scoped to a single connection. Messages are processed serially
+// within a connection.
+type NewMessageProcessorFn func() MessageProcessor
+
+// SingletonMessageProcessor uses the same MessageProcessor for all connections.
+func SingletonMessageProcessor(p MessageProcessor) NewMessageProcessorFn {
+	return func() MessageProcessor {
+		return p
+	}
+}
+
+// NewNoOpMessageProcessor creates a new MessageProcessor that does nothing.
+func NewNoOpMessageProcessor() MessageProcessor {
+	return &noOpMessageProcessor{}
+}
+
+type noOpMessageProcessor struct{}
+
+func (n noOpMessageProcessor) Process(Message) {}
+
+func (n noOpMessageProcessor) Close() {}
+
-// ConsumeFn processes the consumer. This is useful when user want to reuse
-// resource across messages received on the same consumer or have finer level
-// control on how to read messages from consumer.
+// ConsumeFn processes the consumer. This is useful when the user wants to
+// reuse resources across messages received on the same consumer, or needs
+// finer-grained control over how messages are read from the consumer.
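For reference, the three ways a caller can now construct a message handler, sketched with hypothetical values (p, newFn, opts):

package example

import (
	"github.com/m3db/m3/src/msg/consumer"
	"github.com/m3db/m3/src/x/server"
)

func handlerOptions(p consumer.MessageProcessor, newFn consumer.NewMessageProcessorFn, opts consumer.Options) []server.Handler {
	return []server.Handler{
		// Per-connection state: the factory runs once per accepted connection.
		consumer.NewMessageHandler(newFn, opts),
		// One shared processor for all connections; must be concurrency-safe.
		consumer.NewMessageHandler(consumer.SingletonMessageProcessor(p), opts),
		// No processing: NewNoOpMessageProcessor is itself a valid factory,
		// and is what listener.Accept installs for direct Consumer use.
		consumer.NewMessageHandler(consumer.NewNoOpMessageProcessor, opts),
	}
}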