diff --git a/src/aggregator/server/m3msg/server.go b/src/aggregator/server/m3msg/server.go index bf98e4850e..467ff6e5c8 100644 --- a/src/aggregator/server/m3msg/server.go +++ b/src/aggregator/server/m3msg/server.go @@ -50,7 +50,7 @@ func NewServer( logger: opts.InstrumentOptions().Logger(), } } - handler := consumer.NewMessageHandler(newMessageProcessor, opts.ConsumerOptions()) + handler := consumer.NewMessageHandler(consumer.NewMessageProcessorFactory(newMessageProcessor), opts.ConsumerOptions()) return xserver.NewServer(address, handler, opts.ServerOptions()), nil } diff --git a/src/msg/consumer/consumer.go b/src/msg/consumer/consumer.go index b27dbc8b1b..dc0efcb0a3 100644 --- a/src/msg/consumer/consumer.go +++ b/src/msg/consumer/consumer.go @@ -69,7 +69,7 @@ func (l *listener) Accept() (Consumer, error) { return nil, err } - return newConsumer(conn, l.msgPool, l.opts, l.m, NewNoOpMessageProcessor), nil + return newConsumer(conn, l.msgPool, l.opts, l.m, NewNoOpMessageProcessor()), nil } type metrics struct { @@ -123,7 +123,7 @@ func newConsumer( mPool *messagePool, opts Options, m metrics, - newMessageProcessorFn NewMessageProcessorFn, + mp MessageProcessor, ) *consumer { var ( wOpts = xio.ResettableWriterOptions{ @@ -146,7 +146,7 @@ func newConsumer( closed: false, doneCh: make(chan struct{}), m: m, - messageProcessor: newMessageProcessorFn(), + messageProcessor: mp, } } @@ -262,7 +262,6 @@ func (c *consumer) Close() { close(c.doneCh) c.wg.Wait() c.conn.Close() - c.messageProcessor.Close() } type message struct { diff --git a/src/msg/consumer/handlers.go b/src/msg/consumer/handlers.go index 84b0f4ad6e..845efc1e00 100644 --- a/src/msg/consumer/handlers.go +++ b/src/msg/consumer/handlers.go @@ -31,26 +31,27 @@ import ( ) type messageHandler struct { - opts Options - mPool *messagePool - newMessageProcessorFn NewMessageProcessorFn - m metrics + opts Options + mPool *messagePool + mpFactory MessageProcessorFactory + m metrics } // NewMessageHandler creates a new server handler with messageFn. -func NewMessageHandler(newMessageProcessorFn NewMessageProcessorFn, opts Options) server.Handler { +func NewMessageHandler(mpFactory MessageProcessorFactory, opts Options) server.Handler { mPool := newMessagePool(opts.MessagePoolOptions()) mPool.Init() return &messageHandler{ - newMessageProcessorFn: newMessageProcessorFn, - opts: opts, - mPool: mPool, - m: newConsumerMetrics(opts.InstrumentOptions().MetricsScope()), + mpFactory: mpFactory, + opts: opts, + mPool: mPool, + m: newConsumerMetrics(opts.InstrumentOptions().MetricsScope()), } } func (h *messageHandler) Handle(conn net.Conn) { - c := newConsumer(conn, h.mPool, h.opts, h.m, h.newMessageProcessorFn) + mp := h.mpFactory.Create() + c := newConsumer(conn, h.mPool, h.opts, h.m, mp) c.Init() var ( msgErr error @@ -68,7 +69,10 @@ func (h *messageHandler) Handle(conn net.Conn) { if msgErr != nil && msgErr != io.EOF { h.opts.InstrumentOptions().Logger().With(zap.Error(msgErr)).Error("could not read message from consumer") } + mp.Close() c.Close() } -func (h *messageHandler) Close() {} +func (h *messageHandler) Close() { + h.mpFactory.Close() +} diff --git a/src/msg/consumer/handlers_test.go b/src/msg/consumer/handlers_test.go index 1c7f9d99c5..454de168a9 100644 --- a/src/msg/consumer/handlers_test.go +++ b/src/msg/consumer/handlers_test.go @@ -22,6 +22,7 @@ package consumer import ( "net" + "sort" "sync" "testing" @@ -35,12 +36,13 @@ import ( "github.com/stretchr/testify/require" ) -func TestServerWithMessageFn(t *testing.T) { +func TestServerWithSingletonMessageProcessor(t *testing.T) { defer leaktest.Check(t)() var ( data []string wg sync.WaitGroup + mu sync.Mutex ) ctrl := gomock.NewController(t) @@ -49,11 +51,13 @@ func TestServerWithMessageFn(t *testing.T) { p := NewMockMessageProcessor(ctrl) p.EXPECT().Process(gomock.Any()).Do( func(m Message) { + mu.Lock() data = append(data, string(m.Bytes())) + mu.Unlock() m.Ack() wg.Done() }, - ).Times(2) + ).Times(3) // Set a large ack buffer size to make sure the background go routine // can flush it. opts := testOptions().SetAckBufferSize(100) @@ -61,32 +65,43 @@ func TestServerWithMessageFn(t *testing.T) { require.NoError(t, err) s := server.NewServer("a", NewMessageHandler(SingletonMessageProcessor(p), opts), server.NewOptions()) - s.Serve(l) + defer s.Close() + require.NoError(t, s.Serve(l)) - conn, err := net.Dial("tcp", l.Addr().String()) + conn1, err := net.Dial("tcp", l.Addr().String()) + require.NoError(t, err) + conn2, err := net.Dial("tcp", l.Addr().String()) require.NoError(t, err) - wg.Add(1) - err = produce(conn, &testMsg1) + wg.Add(3) + err = produce(conn1, &testMsg1) require.NoError(t, err) - wg.Add(1) - err = produce(conn, &testMsg2) + err = produce(conn1, &testMsg2) + require.NoError(t, err) + err = produce(conn2, &testMsg2) require.NoError(t, err) wg.Wait() - require.Equal(t, string(testMsg1.Value), data[0]) + sort.Strings(data) + require.Equal(t, string(testMsg2.Value), data[0]) require.Equal(t, string(testMsg2.Value), data[1]) + require.Equal(t, string(testMsg1.Value), data[2]) var ack msgpb.Ack - testDecoder := proto.NewDecoder(conn, opts.DecoderOptions(), 10) + testDecoder := proto.NewDecoder(conn1, opts.DecoderOptions(), 10) + err = testDecoder.Decode(&ack) + require.NoError(t, err) + testDecoder = proto.NewDecoder(conn2, opts.DecoderOptions(), 10) err = testDecoder.Decode(&ack) require.NoError(t, err) - require.Equal(t, 2, len(ack.Metadata)) + require.Equal(t, 3, len(ack.Metadata)) + sort.Slice(ack.Metadata, func(i, j int) bool { + return ack.Metadata[i].Id < ack.Metadata[j].Id + }) require.Equal(t, testMsg1.Metadata, ack.Metadata[0]) require.Equal(t, testMsg2.Metadata, ack.Metadata[1]) - + require.Equal(t, testMsg2.Metadata, ack.Metadata[2]) p.EXPECT().Close() - s.Close() } func TestServerMessageDifferentConnections(t *testing.T) { @@ -126,7 +141,8 @@ func TestServerMessageDifferentConnections(t *testing.T) { return mp2 } - s := server.NewServer("a", NewMessageHandler(newMessageProcessor, opts), server.NewOptions()) + s := server.NewServer("a", + NewMessageHandler(NewMessageProcessorFactory(newMessageProcessor), opts), server.NewOptions()) require.NoError(t, err) require.NoError(t, s.Serve(l)) diff --git a/src/msg/consumer/types.go b/src/msg/consumer/types.go index cf556eb678..5f01b84af1 100644 --- a/src/msg/consumer/types.go +++ b/src/msg/consumer/types.go @@ -132,17 +132,59 @@ type MessageProcessor interface { Close() } -// NewMessageProcessorFn creates a new MessageProcessor scoped to a single connection. Messages are processed serially -// in a connection. -type NewMessageProcessorFn func() MessageProcessor +// MessageProcessorFactory creates MessageProcessors. +type MessageProcessorFactory interface { + // Create returns a MessageProcessor. + Create() MessageProcessor + // Close the factory. + Close() +} + +// SingletonMessageProcessor returns a MessageProcessorFactory that shares the same MessageProcessor for all users. The +// MessageProcessor is closed when the factory is closed. +func SingletonMessageProcessor(mp MessageProcessor) MessageProcessorFactory { + return &singletonMessageProcessorFactory{mp: mp, noClose: &noCloseMessageProcessor{mp: mp}} +} + +type singletonMessageProcessorFactory struct { + mp MessageProcessor + noClose MessageProcessor +} + +func (s singletonMessageProcessorFactory) Create() MessageProcessor { + return s.noClose +} + +func (s singletonMessageProcessorFactory) Close() { + s.mp.Close() +} -// SingletonMessageProcessor uses the same MessageProcessor for all connections. -func SingletonMessageProcessor(p MessageProcessor) NewMessageProcessorFn { - return func() MessageProcessor { - return p - } +type noCloseMessageProcessor struct { + mp MessageProcessor } +func (n noCloseMessageProcessor) Process(m Message) { + n.mp.Process(m) +} + +func (n noCloseMessageProcessor) Close() {} + +// NewMessageProcessorFactory returns a MessageProcessorFactory that creates a new MessageProcessor for every call to +// Create. +func NewMessageProcessorFactory(fn func() MessageProcessor) MessageProcessorFactory { + return &messageProcessorFactory{fn: fn} +} + +type messageProcessorFactory struct { + fn func() MessageProcessor +} + +func (m messageProcessorFactory) Create() MessageProcessor { + return m.fn() +} + +func (m messageProcessorFactory) Close() {} + // NewNoOpMessageProcessor creates a new MessageProcessor that does nothing. func NewNoOpMessageProcessor() MessageProcessor { return &noOpMessageProcessor{}