diff --git a/examples/reliable/reliable_client.go b/examples/reliable/reliable_client.go index 230431b4..fa154476 100644 --- a/examples/reliable/reliable_client.go +++ b/examples/reliable/reliable_client.go @@ -11,7 +11,6 @@ import ( "github.com/rabbitmq/rabbitmq-stream-go-client/pkg/amqp" "github.com/rabbitmq/rabbitmq-stream-go-client/pkg/ha" - "github.com/rabbitmq/rabbitmq-stream-go-client/pkg/logs" "github.com/rabbitmq/rabbitmq-stream-go-client/pkg/message" "github.com/rabbitmq/rabbitmq-stream-go-client/pkg/stream" ) @@ -33,10 +32,10 @@ var reSent int32 func main() { // Tune the parameters to test the reliability - const messagesToSend = 5_000_000 - const numberOfProducers = 2 + const messagesToSend = 10_000_000 + const numberOfProducers = 4 const concurrentProducers = 2 - const numberOfConsumers = 2 + const numberOfConsumers = 4 const sendDelay = 1 * time.Millisecond const delayEachMessages = 200 const maxProducersPerClient = 4 @@ -44,7 +43,7 @@ func main() { // reader := bufio.NewReader(os.Stdin) - stream.SetLevelInfo(logs.DEBUG) + //stream.SetLevelInfo(logs.DEBUG) fmt.Println("Reliable Producer/Consumer example") fmt.Println("Connecting to RabbitMQ streaming ...") diff --git a/perfTest/cmd/commands.go b/perfTest/cmd/commands.go index 9cb263c8..9fceccc8 100644 --- a/perfTest/cmd/commands.go +++ b/perfTest/cmd/commands.go @@ -57,7 +57,7 @@ func init() { func setupCli(baseCmd *cobra.Command) { baseCmd.PersistentFlags().StringSliceVarP(&rabbitmqBrokerUrl, "uris", "", []string{stream.LocalhostUriConnection}, "Broker URLs") baseCmd.PersistentFlags().IntVarP(&publishers, "publishers", "", 1, "Number of Publishers") - baseCmd.PersistentFlags().IntVarP(&batchSize, "batch-size", "", 100, "Batch Size, from 1 to 200") + baseCmd.PersistentFlags().IntVarP(&batchSize, "batch-size", "", 200, "Batch Size, from 1 to 300") baseCmd.PersistentFlags().IntVarP(&subEntrySize, "sub-entry-size", "", 1, "SubEntry size, default 1. > 1 Enable the subEntryBatch") baseCmd.PersistentFlags().StringVarP(&compression, "compression", "", "", "Compression for sub batching, none,gzip,lz4,snappy,zstd") baseCmd.PersistentFlags().IntVarP(&consumers, "consumers", "", 1, "Number of Consumers") diff --git a/perfTest/cmd/silent.go b/perfTest/cmd/silent.go index 2467736e..5fa5a709 100644 --- a/perfTest/cmd/silent.go +++ b/perfTest/cmd/silent.go @@ -150,8 +150,8 @@ func startSimulation() error { stream.SetLevelInfo(logs.DEBUG) } - if batchSize < 1 || batchSize > 200 { - logError("Invalid batchSize, must be from 1 to 200, value:%d", batchSize) + if batchSize < 1 || batchSize > 300 { + logError("Invalid batchSize, must be from 1 to 300, value:%d", batchSize) os.Exit(1) } diff --git a/pkg/stream/aggregation.go b/pkg/stream/aggregation.go index 29eaa381..868809d5 100644 --- a/pkg/stream/aggregation.go +++ b/pkg/stream/aggregation.go @@ -52,7 +52,7 @@ func (compression Compression) Lz4() Compression { } type subEntry struct { - messages []messageSequence + messages []*messageSequence publishingId int64 // need to store the publishingId useful in case of aggregation unCompressedSize int diff --git a/pkg/stream/aggregation_test.go b/pkg/stream/aggregation_test.go index cad22bcf..e4c16375 100644 --- a/pkg/stream/aggregation_test.go +++ b/pkg/stream/aggregation_test.go @@ -17,7 +17,7 @@ var _ = Describe("Compression algorithms", func() { messagePayload[i] = 99 } - message := messageSequence{ + message := &messageSequence{ messageBytes: messagePayload, unCompressedSize: len(messagePayload), publishingId: 0, @@ -25,7 +25,7 @@ var _ = Describe("Compression algorithms", func() { entries = &subEntries{ items: []*subEntry{{ - messages: []messageSequence{message}, + messages: []*messageSequence{message}, publishingId: 0, unCompressedSize: len(messagePayload) + 4, sizeInBytes: 0, diff --git a/pkg/stream/coordinator.go b/pkg/stream/coordinator.go index 0d90c64b..e3015f77 100644 --- a/pkg/stream/coordinator.go +++ b/pkg/stream/coordinator.go @@ -65,13 +65,13 @@ func (coordinator *Coordinator) NewProducer( } var producer = &Producer{id: lastId, options: parameters, - mutex: &sync.Mutex{}, + mutex: &sync.RWMutex{}, mutexPending: &sync.Mutex{}, unConfirmedMessages: map[int64]*ConfirmationStatus{}, status: open, messageSequenceCh: make(chan messageSequence, size), pendingMessages: pendingMessagesSequence{ - messages: make([]messageSequence, 0), + messages: make([]*messageSequence, 0), size: initBufferPublishSize, }} coordinator.producers[lastId] = producer diff --git a/pkg/stream/producer.go b/pkg/stream/producer.go index 3cdc080e..d4d61f37 100644 --- a/pkg/stream/producer.go +++ b/pkg/stream/producer.go @@ -51,7 +51,7 @@ func (cs *ConfirmationStatus) GetErrorCode() uint16 { } type pendingMessagesSequence struct { - messages []messageSequence + messages []*messageSequence size int } @@ -60,6 +60,7 @@ type messageSequence struct { unCompressedSize int publishingId int64 filterValue string + refMessage *message.StreamMessage } type Producer struct { @@ -68,7 +69,7 @@ type Producer struct { onClose onInternalClose unConfirmedMessages map[int64]*ConfirmationStatus sequence int64 - mutex *sync.Mutex + mutex *sync.RWMutex mutexPending *sync.Mutex publishConfirm chan []*ConfirmationStatus closeHandler chan Event @@ -170,11 +171,27 @@ func NewProducerOptions() *ProducerOptions { } func (producer *Producer) GetUnConfirmed() map[int64]*ConfirmationStatus { - producer.mutex.Lock() - defer producer.mutex.Unlock() + producer.mutex.RLock() + defer producer.mutex.RUnlock() return producer.unConfirmedMessages } +func (producer *Producer) addUnConfirmedSequences(message []*messageSequence, producerID uint8) { + producer.mutex.Lock() + defer producer.mutex.Unlock() + + for _, msg := range message { + producer.unConfirmedMessages[msg.publishingId] = + &ConfirmationStatus{ + inserted: time.Now(), + message: *msg.refMessage, + producerID: producerID, + publishingId: msg.publishingId, + confirmed: false, + } + } + +} func (producer *Producer) addUnConfirmed(sequence int64, message message.StreamMessage, producerID uint8) { producer.mutex.Lock() defer producer.mutex.Unlock() @@ -191,6 +208,18 @@ func (po *ProducerOptions) isSubEntriesBatching() bool { return po.SubEntrySize > 1 } +func (producer *Producer) removeFromConfirmationStatus(status []*ConfirmationStatus) { + producer.mutex.Lock() + defer producer.mutex.Unlock() + + for _, msg := range status { + delete(producer.unConfirmedMessages, msg.publishingId) + for _, linked := range msg.linkedTo { + delete(producer.unConfirmedMessages, linked.publishingId) + } + } +} + func (producer *Producer) removeUnConfirmed(sequence int64) { producer.mutex.Lock() defer producer.mutex.Unlock() @@ -210,13 +239,13 @@ func (producer *Producer) lenPendingMessages() int { } func (producer *Producer) getUnConfirmed(sequence int64) *ConfirmationStatus { - producer.mutex.Lock() - defer producer.mutex.Unlock() + producer.mutex.RLock() + defer producer.mutex.RUnlock() return producer.unConfirmedMessages[sequence] } func (producer *Producer) NotifyPublishConfirmation() ChannelPublishConfirm { - ch := make(chan []*ConfirmationStatus) + ch := make(chan []*ConfirmationStatus, 1) producer.publishConfirm = ch return ch } @@ -263,19 +292,26 @@ func (producer *Producer) startUnconfirmedMessagesTimeOutTask() { go func() { for producer.getStatus() == open { time.Sleep(2 * time.Second) - producer.mutex.Lock() + toRemove := make([]*ConfirmationStatus, 0) + // check the unconfirmed messages and remove the one that are expired + // use the RLock to avoid blocking the producer + producer.mutex.RLock() for _, msg := range producer.unConfirmedMessages { if time.Since(msg.inserted) > producer.options.ConfirmationTimeOut { msg.err = ConfirmationTimoutError msg.errorCode = timeoutError msg.confirmed = false - if producer.publishConfirm != nil { - producer.publishConfirm <- []*ConfirmationStatus{msg} - } - delete(producer.unConfirmedMessages, msg.publishingId) + toRemove = append(toRemove, msg) + } + } + producer.mutex.RUnlock() + + if len(toRemove) > 0 { + producer.removeFromConfirmationStatus(toRemove) + if producer.publishConfirm != nil { + producer.publishConfirm <- toRemove } } - producer.mutex.Unlock() } time.Sleep(5 * time.Second) producer.flushUnConfirmedMessages(timeoutError, ConfirmationTimoutError) @@ -312,7 +348,7 @@ func (producer *Producer) startPublishTask() { } producer.pendingMessages.size += msg.unCompressedSize - producer.pendingMessages.messages = append(producer.pendingMessages.messages, msg) + producer.pendingMessages.messages = append(producer.pendingMessages.messages, &msg) if len(producer.pendingMessages.messages) >= (producer.options.BatchSize) { producer.sendBufferedMessages() } @@ -384,7 +420,7 @@ func (producer *Producer) assignPublishingID(message message.StreamMessage) int6 // BatchSend is the primitive method to send messages to the stream, the method Send prepares the messages and // calls BatchSend internally. func (producer *Producer) BatchSend(batchMessages []message.StreamMessage) error { - var messagesSequence = make([]messageSequence, len(batchMessages)) + var messagesSequence = make([]*messageSequence, len(batchMessages)) totalBufferToSend := 0 for i, batchMessage := range batchMessages { messageBytes, err := batchMessage.MarshalBinary() @@ -398,16 +434,17 @@ func (producer *Producer) BatchSend(batchMessages []message.StreamMessage) error sequence := producer.assignPublishingID(batchMessage) totalBufferToSend += len(messageBytes) - messagesSequence[i] = messageSequence{ + messagesSequence[i] = &messageSequence{ messageBytes: messageBytes, unCompressedSize: len(messageBytes), publishingId: sequence, filterValue: filterValue, + refMessage: &batchMessage, } - - producer.addUnConfirmed(sequence, batchMessage, producer.id) } + producer.addUnConfirmedSequences(messagesSequence, producer.GetID()) + if totalBufferToSend+initBufferPublishSize > producer.options.client.tuneState.requestedMaxFrameSize { for _, msg := range messagesSequence { @@ -432,11 +469,11 @@ func (producer *Producer) BatchSend(batchMessages []message.StreamMessage) error func (producer *Producer) GetID() uint8 { return producer.id } -func (producer *Producer) internalBatchSend(messagesSequence []messageSequence) error { +func (producer *Producer) internalBatchSend(messagesSequence []*messageSequence) error { return producer.internalBatchSendProdId(messagesSequence, producer.GetID()) } -func (producer *Producer) simpleAggregation(messagesSequence []messageSequence, b *bufio.Writer) { +func (producer *Producer) simpleAggregation(messagesSequence []*messageSequence, b *bufio.Writer) { for _, msg := range messagesSequence { r := msg.messageBytes writeBLong(b, msg.publishingId) // publishingId @@ -459,13 +496,15 @@ func (producer *Producer) subEntryAggregation(aggregation subEntries, b *bufio.W } } -func (producer *Producer) aggregateEntities(msgs []messageSequence, size int, compression Compression) (subEntries, error) { +func (producer *Producer) aggregateEntities(msgs []*messageSequence, size int, compression Compression) (subEntries, error) { subEntries := subEntries{} var entry *subEntry for _, msg := range msgs { if len(subEntries.items) == 0 || len(entry.messages) >= size { - entry = &subEntry{} + entry = &subEntry{ + messages: make([]*messageSequence, 0), + } entry.publishingId = -1 subEntries.items = append(subEntries.items, entry) } @@ -506,7 +545,7 @@ func (producer *Producer) aggregateEntities(msgs []messageSequence, size int, co /// the producer id is always the producer.GetID(). This function is needed only for testing // some condition, like simulate publish error, see -func (producer *Producer) internalBatchSendProdId(messagesSequence []messageSequence, producerID uint8) error { +func (producer *Producer) internalBatchSendProdId(messagesSequence []*messageSequence, producerID uint8) error { producer.options.client.socket.mutex.Lock() defer producer.options.client.socket.mutex.Unlock() if producer.getStatus() == closed { @@ -656,7 +695,7 @@ func (producer *Producer) GetName() string { return producer.options.Name } -func (producer *Producer) sendWithFilter(messagesSequence []messageSequence, producerID uint8) error { +func (producer *Producer) sendWithFilter(messagesSequence []*messageSequence, producerID uint8) error { frameHeaderLength := initBufferPublishSize var msgLen int for _, msg := range messagesSequence { diff --git a/pkg/stream/producer_test.go b/pkg/stream/producer_test.go index 941751bd..6216c2cb 100644 --- a/pkg/stream/producer_test.go +++ b/pkg/stream/producer_test.go @@ -555,11 +555,20 @@ var _ = Describe("Streaming Producers", func() { } }(chPublishError) - var messagesSequence = make([]messageSequence, 1) + var messagesSequence = make([]*messageSequence, 1) + + for i := 0; i < 1; i++ { + s := make([]byte, 50) + messagesSequence[i] = &messageSequence{ + messageBytes: s, + unCompressedSize: len(s), + } + } + msg := amqp.NewMessage([]byte("test")) msg.SetPublishingId(1) messageBytes, _ := msg.MarshalBinary() - messagesSequence[0] = messageSequence{ + messagesSequence[0] = &messageSequence{ messageBytes: messageBytes, unCompressedSize: len(messageBytes), } @@ -638,35 +647,81 @@ var _ = Describe("Streaming Producers", func() { NewProducerOptions().SetBatchPublishingDelay(100). SetSubEntrySize(77)) Expect(err).NotTo(HaveOccurred()) - messagesSequence := make([]messageSequence, 201) + messagesSequence := make([]*messageSequence, 201) + + for i := 0; i < 201; i++ { + s := make([]byte, 50) + messagesSequence[i] = &messageSequence{ + messageBytes: s, + unCompressedSize: len(s), + } + } + entries, err := producer.aggregateEntities(messagesSequence, producer.options.SubEntrySize, producer.options.Compression) Expect(err).NotTo(HaveOccurred()) Expect(len(entries.items)).To(Equal(3)) - messagesSequence = make([]messageSequence, 100) + messagesSequence = make([]*messageSequence, 100) + + for i := 0; i < 100; i++ { + + s := make([]byte, 50) + messagesSequence[i] = &messageSequence{ + messageBytes: s, + unCompressedSize: len(s), + } + } + entries, err = producer.aggregateEntities(messagesSequence, producer.options.SubEntrySize, producer.options.Compression) Expect(err).NotTo(HaveOccurred()) Expect(len(entries.items)).To(Equal(2)) - messagesSequence = make([]messageSequence, 1) + messagesSequence = make([]*messageSequence, 1) + + for i := 0; i < 1; i++ { + s := make([]byte, 50) + messagesSequence[i] = &messageSequence{ + messageBytes: s, + unCompressedSize: len(s), + } + } + entries, err = producer.aggregateEntities(messagesSequence, producer.options.SubEntrySize, producer.options.Compression) Expect(err).NotTo(HaveOccurred()) Expect(len(entries.items)).To(Equal(1)) - messagesSequence = make([]messageSequence, 1000) + messagesSequence = make([]*messageSequence, 1000) + + for i := 0; i < 1000; i++ { + s := make([]byte, 50) + messagesSequence[i] = &messageSequence{ + messageBytes: s, + unCompressedSize: len(s), + } + } + entries, err = producer.aggregateEntities(messagesSequence, producer.options.SubEntrySize, producer.options.Compression) Expect(err).NotTo(HaveOccurred()) Expect(len(entries.items)).To(Equal(13)) - messagesSequence = make([]messageSequence, 14) + messagesSequence = make([]*messageSequence, 14) + + for i := 0; i < 14; i++ { + s := make([]byte, 50) + messagesSequence[i] = &messageSequence{ + messageBytes: s, + unCompressedSize: len(s), + } + } + entries, err = producer.aggregateEntities(messagesSequence, 13, producer.options.Compression) Expect(err).NotTo(HaveOccurred()) diff --git a/pkg/stream/server_frame.go b/pkg/stream/server_frame.go index 62f55050..3f4ebf79 100644 --- a/pkg/stream/server_frame.go +++ b/pkg/stream/server_frame.go @@ -259,7 +259,6 @@ func (c *Client) handleConfirm(readProtocol *ReaderProtocol, r *bufio.Reader) in if m != nil { m.confirmed = true unConfirmed = append(unConfirmed, m) - producer.removeUnConfirmed(m.publishingId) // in case of sub-batch entry the client receives only // one publishingId (or sequence) @@ -267,20 +266,18 @@ func (c *Client) handleConfirm(readProtocol *ReaderProtocol, r *bufio.Reader) in for _, message := range m.linkedTo { message.confirmed = true unConfirmed = append(unConfirmed, message) - producer.removeUnConfirmed(message.publishingId) } } - //} else { - //logs.LogWarn("message %d not found in confirmation", seq) - //} publishingIdCount-- } - producer.mutex.Lock() + producer.removeFromConfirmationStatus(unConfirmed) + + //producer.mutex.Lock() if producer.publishConfirm != nil { producer.publishConfirm <- unConfirmed } - producer.mutex.Unlock() + //producer.mutex.Unlock() return 0 } diff --git a/pkg/stream/super_stream_consumer_test.go b/pkg/stream/super_stream_consumer_test.go index fdb5045d..2340c0f9 100644 --- a/pkg/stream/super_stream_consumer_test.go +++ b/pkg/stream/super_stream_consumer_test.go @@ -74,8 +74,7 @@ var _ = Describe("Super Stream Producer", Label("super-stream-consumer"), func() env, err := NewEnvironment(nil) Expect(err).NotTo(HaveOccurred()) - superStream := "first-super-stream-consumer" - + superStream := fmt.Sprintf("first-super-stream-consumer-%d", time.Now().Unix()) Expect(env.DeclareSuperStream(superStream, NewPartitionsOptions(3))).NotTo(HaveOccurred()) messagesHandler := func(consumerContext ConsumerContext, message *amqp.Message) {} @@ -104,7 +103,7 @@ var _ = Describe("Super Stream Producer", Label("super-stream-consumer"), func() It("validate super stream consumer ", func() { env, err := NewEnvironment(nil) Expect(err).NotTo(HaveOccurred()) - superStream := "validate-super-stream-consumer" + superStream := fmt.Sprintf("validate-super-stream-consumer-%d", time.Now().Unix()) Expect(env.DeclareSuperStream(superStream, NewPartitionsOptions(3))).NotTo(HaveOccurred()) messagesHandler := func(consumerContext ConsumerContext, message *amqp.Message) {} @@ -131,7 +130,7 @@ var _ = Describe("Super Stream Producer", Label("super-stream-consumer"), func() env, err := NewEnvironment(nil) Expect(err).NotTo(HaveOccurred()) - superStream := "consume-20messages-super-stream-consumer" + superStream := fmt.Sprintf("consume-20messages-super-stream-consumer-%d", time.Now().Unix()) Expect(env.DeclareSuperStream(superStream, NewPartitionsOptions(3))).NotTo(HaveOccurred()) var receivedMessages int32 @@ -233,7 +232,7 @@ var _ = Describe("Super Stream Producer", Label("super-stream-consumer"), func() env, err := NewEnvironment(nil) Expect(err).NotTo(HaveOccurred()) - const superStream = "reconnect-super-stream-consumer" + var superStream = fmt.Sprintf("reconnect-super-stream-consumer-%d", time.Now().Unix()) Expect(env.DeclareSuperStream(superStream, NewPartitionsOptions(3). SetBalancedLeaderLocator())).NotTo(HaveOccurred()) @@ -288,7 +287,7 @@ var _ = Describe("Super Stream Producer", Label("super-stream-consumer"), func() env, err := NewEnvironment(nil) Expect(err).NotTo(HaveOccurred()) - superStream := "sac-super-stream-the-second-should-restart-consume" + superStream := fmt.Sprintf("sac-super-stream-the-second-should-restart-consume-%d", time.Now().Unix()) Expect(env.DeclareSuperStream(superStream, NewPartitionsOptions(2))).NotTo(HaveOccurred()) const appName = "MyApplication" @@ -360,7 +359,7 @@ var _ = Describe("Super Stream Producer", Label("super-stream-consumer"), func() env, err := NewEnvironment(nil) Expect(err).NotTo(HaveOccurred()) - superStream := "filtering-super-stream-should-consume-only-one-country" + superStream := fmt.Sprintf("filtering-super-stream-should-consume-only-one-country-%d", time.Now().Unix()) Expect(env.DeclareSuperStream(superStream, NewPartitionsOptions(2))).NotTo(HaveOccurred()) @@ -372,7 +371,7 @@ var _ = Describe("Super Stream Producer", Label("super-stream-consumer"), func() }))) Expect(err).NotTo(HaveOccurred()) - for i := 0; i < 25; i++ { + for i := 0; i < 7; i++ { msg := amqp.NewMessage(make([]byte, 0)) msg.ApplicationProperties = map[string]interface{}{"county": "italy"} msg.Properties = &amqp.MessageProperties{ @@ -383,9 +382,9 @@ var _ = Describe("Super Stream Producer", Label("super-stream-consumer"), func() // the sleep is to be sure the messages are stored in a chunk // so the filter will be applied, so the first chunk will contain only Italy - time.Sleep(1 * time.Second) + time.Sleep(1500 * time.Millisecond) - for i := 0; i < 25; i++ { + for i := 0; i < 6; i++ { msg := amqp.NewMessage(make([]byte, 0)) msg.ApplicationProperties = map[string]interface{}{"county": "spain"} msg.Properties = &amqp.MessageProperties{ @@ -394,7 +393,7 @@ var _ = Describe("Super Stream Producer", Label("super-stream-consumer"), func() Expect(superProducer.Send(msg)).NotTo(HaveOccurred()) } - time.Sleep(500 * time.Millisecond) + time.Sleep(1500 * time.Millisecond) // we don't need to apply any post filter here // the server side filter is enough @@ -412,7 +411,7 @@ var _ = Describe("Super Stream Producer", Label("super-stream-consumer"), func() Expect(err).NotTo(HaveOccurred()) time.Sleep(500 * time.Millisecond) - Eventually(func() int32 { return atomic.LoadInt32(&consumerItaly) }).WithPolling(300 * time.Millisecond).WithTimeout(5 * time.Second).Should(Equal(int32(25))) + Eventually(func() int32 { return atomic.LoadInt32(&consumerItaly) }).WithPolling(300 * time.Millisecond).WithTimeout(5 * time.Second).Should(Equal(int32(7))) Expect(superProducer.Close()).NotTo(HaveOccurred()) Expect(superStreamConsumer.Close()).NotTo(HaveOccurred()) @@ -428,7 +427,7 @@ var _ = Describe("Super Stream Producer", Label("super-stream-consumer"), func() env, err := NewEnvironment(nil) Expect(err).NotTo(HaveOccurred()) - superStream := "filtering-super-stream-should-consume-only-one-country" + superStream := fmt.Sprintf("filtering-super-stream-should-consume-only-one-country-%d", time.Now().Unix()) Expect(env.DeclareSuperStream(superStream, NewPartitionsOptions(2))).NotTo(HaveOccurred()) @@ -515,7 +514,7 @@ var _ = Describe("Super Stream Producer", Label("super-stream-consumer"), func() env, err := NewEnvironment(nil) Expect(err).NotTo(HaveOccurred()) - superStream := "super-stream-consumer-with-autocommit" + superStream := fmt.Sprintf("super-stream-consumer-with-autocommit-%d", time.Now().Unix()) Expect(env.DeclareSuperStream(superStream, NewPartitionsOptions(2))).NotTo(HaveOccurred()) diff --git a/pkg/stream/super_stream_producer.go b/pkg/stream/super_stream_producer.go index 47fe89c8..fc7c2504 100644 --- a/pkg/stream/super_stream_producer.go +++ b/pkg/stream/super_stream_producer.go @@ -45,6 +45,10 @@ func NewHashRoutingStrategy(routingKeyExtractor func(message message.StreamMessa func (h *HashRoutingStrategy) Route(message message.StreamMessage, partitions []string) ([]string, error) { + if message == nil { + return nil, fmt.Errorf("message is nil") + } + key := h.RoutingKeyExtractor(message) murmurHash := murmur3.New32WithSeed(SEED) _, _ = murmurHash.Write([]byte(key)) @@ -213,6 +217,7 @@ func (s *SuperStreamProducer) init() error { partitions, err := s.env.QueryPartitions(s.SuperStream) s.partitions = partitions + if err != nil { return err } diff --git a/pkg/stream/super_stream_producer_test.go b/pkg/stream/super_stream_producer_test.go index e50defad..a06de884 100644 --- a/pkg/stream/super_stream_producer_test.go +++ b/pkg/stream/super_stream_producer_test.go @@ -103,7 +103,7 @@ var _ = Describe("Super Stream Producer", Label("super-stream"), func() { env, err := NewEnvironment(nil) Expect(err).NotTo(HaveOccurred()) - const superStream = "first-super-stream-producer" + var superStream = fmt.Sprintf("first-super-stream-producer-%d", time.Now().Unix()) Expect(env.DeclareSuperStream(superStream, NewPartitionsOptions(3). SetBalancedLeaderLocator(). @@ -134,7 +134,7 @@ var _ = Describe("Super Stream Producer", Label("super-stream"), func() { Expect(err).NotTo(HaveOccurred()) // we do this test to be sure that the producer is able to Send messages to all the partitions // the same was done in .NET client and python client - const superStream = "super-stream-send-messages-to-all-partitions" + var superStream = fmt.Sprintf("super-stream-send-messages-to-all-partitions-%d", time.Now().Unix()) msgReceived := make(map[string]int) mutex := sync.Mutex{} @@ -202,7 +202,7 @@ var _ = Describe("Super Stream Producer", Label("super-stream"), func() { It("should handle three close ( one for partition )", func() { env, err := NewEnvironment(nil) Expect(err).NotTo(HaveOccurred()) - const superStream = "close-super-stream-producer" + var superStream = fmt.Sprintf("close-super-stream-producer-%d", time.Now().Unix()) var closedMap = make(map[string]bool) mutex := sync.Mutex{} Expect(env.DeclareSuperStream(superStream, @@ -244,7 +244,7 @@ var _ = Describe("Super Stream Producer", Label("super-stream"), func() { env, err := NewEnvironment(nil) Expect(err).NotTo(HaveOccurred()) - const superStream = "reconnect-super-stream-producer" + var superStream = fmt.Sprintf("reconnect-super-stream-producer-%d", time.Now().Unix()) Expect(env.DeclareSuperStream(superStream, NewPartitionsOptions(3). SetBalancedLeaderLocator())).NotTo(HaveOccurred()) @@ -309,7 +309,7 @@ var _ = Describe("Super Stream Producer", Label("super-stream"), func() { env, err := NewEnvironment(nil) Expect(err).NotTo(HaveOccurred()) - const superStream = "key-super-stream-producer" + var superStream = fmt.Sprintf("key-super-stream-producer-%d", time.Now().Unix()) Expect(env.DeclareSuperStream(superStream, options. SetMaxSegmentSizeBytes(ByteCapacity{}.GB(1)). SetMaxAge(3*time.Hour). @@ -323,19 +323,19 @@ var _ = Describe("Super Stream Producer", Label("super-stream"), func() { Expect(err).NotTo(HaveOccurred()) Expect(route).NotTo(BeNil()) Expect(route).To(HaveLen(1)) - Expect(route[0]).To(Equal("key-super-stream-producer-italy")) + Expect(route[0]).To(Equal(fmt.Sprintf("%s-italy", superStream))) route, err = client.queryRoute(superStream, "spain") Expect(err).NotTo(HaveOccurred()) Expect(route).NotTo(BeNil()) Expect(route).To(HaveLen(1)) - Expect(route[0]).To(Equal("key-super-stream-producer-spain")) + Expect(route[0]).To(Equal(fmt.Sprintf("%s-spain", superStream))) route, err = client.queryRoute(superStream, "france") Expect(err).NotTo(HaveOccurred()) Expect(route).NotTo(BeNil()) Expect(route).To(HaveLen(1)) - Expect(route[0]).To(Equal("key-super-stream-producer-france")) + Expect(route[0]).To(Equal(fmt.Sprintf("%s-france", superStream))) // here we test the case where the key is not found // the client should return an empty list @@ -368,7 +368,7 @@ var _ = Describe("Super Stream Producer", Label("super-stream"), func() { env, err := NewEnvironment(nil) Expect(err).NotTo(HaveOccurred()) - const superStream = "key-super-stream-producer-with-3-keys" + var superStream = fmt.Sprintf("key-super-stream-producer-with-3-keys-%d", time.Now().Unix()) countries := []string{"italy", "france", "spain"} Expect(env.DeclareSuperStream(superStream, NewBindingsOptions(countries))).NotTo(HaveOccurred()) @@ -430,7 +430,7 @@ var _ = Describe("Super Stream Producer", Label("super-stream"), func() { env, err := NewEnvironment(nil) Expect(err).NotTo(HaveOccurred()) - const superStream = "custom-routing-strategy" + var superStream = fmt.Sprintf("custom-routing-strategy-%d", time.Now().Unix()) Expect(env.DeclareSuperStream(superStream, NewPartitionsOptions(3))).NotTo(HaveOccurred()) superProducer, err := env.NewSuperStreamProducer(superStream, NewSuperStreamProducerOptions( @@ -455,7 +455,7 @@ var _ = Describe("Super Stream Producer", Label("super-stream"), func() { // Test is to validate the error when the producer is already connected env, err := NewEnvironment(nil) Expect(err).NotTo(HaveOccurred()) - const superStream = "already-connected-super-stream-producer" + var superStream = fmt.Sprintf("already-connected-super-stream-producer-%d", time.Now().Unix()) Expect(env.DeclareSuperStream(superStream, NewPartitionsOptions(3))).NotTo(HaveOccurred()) superProducer, err := env.NewSuperStreamProducer(superStream, NewSuperStreamProducerOptions(