diff --git a/README.md b/README.md index a09e96a..f08b532 100644 --- a/README.md +++ b/README.md @@ -296,6 +296,17 @@ p.Produce( ) ``` + +### Produce using partition key +The partition key will be used to produce messages to a specific partition. + +```go +p.Produce( + "", + memphis.ProducerPartitionKey() +) +``` + ### Destroying a Producer ```go @@ -346,7 +357,9 @@ func handler(msgs []*memphis.Msg, err error, ctx context.Context) { } } -consumer.Consume(handler) +consumer.Consume(handler, + memphis.ConsumerPartitionKey() // use the partition key to consume from a specific partition (if not specified consume in a Round Robin fashion) +) ``` ### Fetch a single batch of messages @@ -360,13 +373,17 @@ msgs, err := conn.FetchMessages("", "", memphis.FetchConsumerErrorHandler(func(*Consumer, error){}) memphis.FetchStartConsumeFromSeq()// start consuming from a specific sequence. defaults to 1 memphis.FetchLastMessages()// consume the last N messages, defaults to -1 (all messages in the station)) + memphis.FetchPartitionKey()// use the partition key to consume from a specific partition (if not specified consume in a Round Robin fashion) ``` ### Fetch a single batch of messages after creating a consumer `prefetch = true` will prefetch next batch of messages and save it in memory for future Fetch() request
Note: Use a higher MaxAckTime as the messages will sit in a local cache for some time before processing ```go -msgs, err := consumer.Fetch( int, bool) +msgs, err := consumer.Fetch( int, + bool, + memphis.ConsumerPartitionKey() // use the partition key to consume from a specific partition (if not specified consume in a Round Robin fashion) + ) ``` ### Acknowledging a Message diff --git a/connect.go b/connect.go index 1f240cf..1155e0d 100644 --- a/connect.go +++ b/connect.go @@ -32,12 +32,14 @@ import ( "github.com/gofrs/uuid" "github.com/nats-io/nats.go" + "github.com/spaolacci/murmur3" ) const ( sdkClientsUpdatesSubject = "$memphis_sdk_clients_updates" maxBatchSize = 5000 memphisGlobalAccountName = "$memphis" + SEED = 1234 ) var stationUpdatesSubsLock sync.Mutex @@ -94,6 +96,7 @@ type FetchOpts struct { StartConsumeFromSequence uint64 LastMessages int64 Prefetch bool + FetchPartitionKey string } // getDefaultConsumerOptions - returns default configuration options for consumers. @@ -109,6 +112,7 @@ func getDefaultFetchOptions() FetchOpts { StartConsumeFromSequence: 1, LastMessages: -1, Prefetch: false, + FetchPartitionKey: "", } } @@ -799,7 +803,7 @@ func (c *Conn) FetchMessages(stationName string, consumerName string, opts ...Fe } else { consumer = cons } - msgs, err := consumer.Fetch(defaultOpts.BatchSize, defaultOpts.Prefetch) + msgs, err := consumer.Fetch(defaultOpts.BatchSize, defaultOpts.Prefetch, ConsumerPartitionKey(defaultOpts.FetchPartitionKey)) if err != nil { return nil, err } @@ -814,6 +818,14 @@ func FetchConsumerGroup(cg string) FetchOpt { } } +// FetchPartitionKey - partition key to consume from. +func FetchPartitionKey(partitionKey string) FetchOpt { + return func(opts *FetchOpts) error { + opts.FetchPartitionKey = partitionKey + return nil + } +} + // BatchSize - pull batch size. 
func FetchBatchSize(batchSize int) FetchOpt { return func(opts *FetchOpts) error { @@ -883,3 +895,13 @@ func init() { applicationId = appId.String() } } + +func (c *Conn) GetPartitionFromKey(key string, stationName string) (int, error) { + mur3 := murmur3.New32WithSeed(SEED) + _, err := mur3.Write([]byte(key)) + if err != nil { + return -1, err + } + PartitionIndex := int(mur3.Sum32()) % len(c.stationPartitions[stationName].PartitionsList) + return c.stationPartitions[stationName].PartitionsList[PartitionIndex], nil +} diff --git a/consumer.go b/consumer.go index 3519418..94d6599 100644 --- a/consumer.go +++ b/consumer.go @@ -303,6 +303,7 @@ func (opts *ConsumerOpts) createConsumer(c *Conn) (*Consumer, error) { sn := getInternalName(consumer.stationName) durable := getInternalName(consumer.ConsumerGroup) + if len(consumer.conn.stationPartitions[sn].PartitionsList) == 0 { consumer.subscriptions = make(map[int]*nats.Subscription, 1) subj := sn + ".final" @@ -382,9 +383,11 @@ func (c *Consumer) pingConsumer() { }(sub) } wg.Wait() - if generalErr != nil && (strings.Contains(generalErr.Error(), "consumer not found") || strings.Contains(generalErr.Error(), "stream not found")) { - c.subscriptionActive = false - c.callErrHandler(ConsumerErrStationUnreachable) + if generalErr != nil { + if strings.Contains(generalErr.Error(), "consumer not found") || strings.Contains(generalErr.Error(), "stream not found") { + c.subscriptionActive = false + c.callErrHandler(ConsumerErrStationUnreachable) + } } case <-c.pingQuit: ticker.Stop() @@ -401,11 +404,44 @@ func (c *Consumer) SetContext(ctx context.Context) { // ConsumeHandler - handler for consumed messages type ConsumeHandler func([]*Msg, error, context.Context) +// ConsumingOpts - configuration options for consuming messages +type ConsumingOpts struct { + ConsumerPartitionKey string +} + +type ConsumingOpt func(*ConsumingOpts) error + +// ConsumerPartitionKey - Partition key for the consumer to consume from +func 
ConsumerPartitionKey(ConsumerPartitionKey string) ConsumingOpt { + return func(opts *ConsumingOpts) error { + opts.ConsumerPartitionKey = ConsumerPartitionKey + return nil + } +} + +func getDefaultConsumingOptions() ConsumingOpts { + return ConsumingOpts{ + ConsumerPartitionKey: "", + } +} + // Consumer.Consume - start consuming messages according to the interval configured in the consumer object. // When a batch is consumed the handlerFunc will be called. -func (c *Consumer) Consume(handlerFunc ConsumeHandler) error { - go func(c *Consumer) { - msgs, err := c.fetchSubscription() +func (c *Consumer) Consume(handlerFunc ConsumeHandler, opts ...ConsumingOpt) error { + + defaultOpts := getDefaultConsumingOptions() + + for _, opt := range opts { + if opt != nil { + if err := opt(&defaultOpts); err != nil { + return memphisError(err) + } + } + } + + go func(c *Consumer, partitionKey string) { + + msgs, err := c.fetchSubscription(partitionKey) handlerFunc(msgs, memphisError(err), c.context) c.dlsHandlerFunc = handlerFunc ticker := time.NewTicker(c.PullInterval) @@ -421,13 +457,13 @@ func (c *Consumer) Consume(handlerFunc ConsumeHandler) error { select { case <-ticker.C: - msgs, err := c.fetchSubscription() + msgs, err := c.fetchSubscription(partitionKey) handlerFunc(msgs, memphisError(err), nil) case <-c.consumeQuit: return } } - }(c) + }(c, defaultOpts.ConsumerPartitionKey) c.consumeActive = true return nil } @@ -442,15 +478,26 @@ func (c *Consumer) StopConsume() { c.consumeActive = false } -func (c *Consumer) fetchSubscription() ([]*Msg, error) { +func (c *Consumer) fetchSubscription(partitionKey string) ([]*Msg, error) { + if !c.subscriptionActive { return nil, memphisError(errors.New("station unreachable")) } wrappedMsgs := make([]*Msg, 0, c.BatchSize) partitionNumber := 1 + if len(c.subscriptions) > 1 { - partitionNumber = c.PartitionGenerator.Next() + if partitionKey != "" { + partitionFromKey, err := c.conn.GetPartitionFromKey(partitionKey, c.stationName) + if err 
!= nil { + return nil, memphisError(err) + } + partitionNumber = partitionFromKey + } else { + partitionNumber = c.PartitionGenerator.Next() + } } + msgs, err := c.subscriptions[partitionNumber].Fetch(c.BatchSize) if err != nil && err != nats.ErrTimeout { c.subscriptionActive = false @@ -468,13 +515,14 @@ type fetchResult struct { err error } -func (c *Consumer) fetchSubscriprionWithTimeout() ([]*Msg, error) { +func (c *Consumer) fetchSubscriprionWithTimeout(partitionKey string) ([]*Msg, error) { timeoutDuration := c.BatchMaxTimeToWait out := make(chan fetchResult, 1) - go func() { - msgs, err := c.fetchSubscription() + + go func(partitionKey string) { + msgs, err := c.fetchSubscription(partitionKey) out <- fetchResult{msgs: msgs, err: memphisError(err)} - }() + }(partitionKey) select { case <-time.After(timeoutDuration): return nil, memphisError(errors.New("fetch timed out")) @@ -485,11 +533,21 @@ func (c *Consumer) fetchSubscriprionWithTimeout() ([]*Msg, error) { } // Fetch - immediately fetch a batch of messages. 
-func (c *Consumer) Fetch(batchSize int, prefetch bool) ([]*Msg, error) { +func (c *Consumer) Fetch(batchSize int, prefetch bool, opts ...ConsumingOpt) ([]*Msg, error) { if batchSize > maxBatchSize { return nil, memphisError(errors.New("Batch size can not be greater than " + strconv.Itoa(maxBatchSize))) } + defaultOpts := getDefaultConsumingOptions() + + for _, opt := range opts { + if opt != nil { + if err := opt(&defaultOpts); err != nil { + return nil, memphisError(err) + } + } + } + c.BatchSize = batchSize var msgs []*Msg if len(c.dlsMsgs) > 0 { @@ -523,15 +581,15 @@ func (c *Consumer) Fetch(batchSize int, prefetch bool) ([]*Msg, error) { } c.conn.prefetchedMsgs.lock.Unlock() if prefetch { - go c.prefetchMsgs() + go c.prefetchMsgs(defaultOpts.ConsumerPartitionKey) } if len(msgs) > 0 { return msgs, nil } - return c.fetchSubscriprionWithTimeout() + return c.fetchSubscriprionWithTimeout(defaultOpts.ConsumerPartitionKey) } -func (c *Consumer) prefetchMsgs() { +func (c *Consumer) prefetchMsgs(partitionKey string) { c.conn.prefetchedMsgs.lock.Lock() defer c.conn.prefetchedMsgs.lock.Unlock() lowerCaseStationName := getLowerCaseName(c.stationName) @@ -541,7 +599,7 @@ func (c *Consumer) prefetchMsgs() { if _, ok := c.conn.prefetchedMsgs.msgs[lowerCaseStationName][c.ConsumerGroup]; !ok { c.conn.prefetchedMsgs.msgs[lowerCaseStationName][c.ConsumerGroup] = make([]*Msg, 0) } - msgs, err := c.fetchSubscriprionWithTimeout() + msgs, err := c.fetchSubscriprionWithTimeout(partitionKey) if err == nil { c.conn.prefetchedMsgs.msgs[lowerCaseStationName][c.ConsumerGroup] = append(c.conn.prefetchedMsgs.msgs[lowerCaseStationName][c.ConsumerGroup], msgs...) 
} diff --git a/go.mod b/go.mod index 2da33d1..ec0938a 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/nats-io/nats-server/v2 v2.9.5 // indirect github.com/nats-io/nkeys v0.4.4 // indirect github.com/nats-io/nuid v1.0.1 // indirect + github.com/spaolacci/murmur3 v1.1.0 // indirect golang.org/x/crypto v0.6.0 // indirect golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect golang.org/x/sys v0.5.0 // indirect diff --git a/go.sum b/go.sum index 916deef..21d532d 100644 --- a/go.sum +++ b/go.sum @@ -88,6 +88,8 @@ github.com/santhosh-tekuri/jsonschema/v5 v5.1.0 h1:wSUNu/w/7OQ0Y3NVnfTU5uxzXY4uM github.com/santhosh-tekuri/jsonschema/v5 v5.1.0/go.mod h1:FKdcjfQW6rpZSnxxUvEA5H/cDPdvJ/SZJQLWWXWGrZ0= github.com/santhosh-tekuri/jsonschema/v5 v5.3.0 h1:uIkTLo0AGRc8l7h5l9r+GcYi9qfVPt6lD4/bhmzfiKo= github.com/santhosh-tekuri/jsonschema/v5 v5.3.0/go.mod h1:FKdcjfQW6rpZSnxxUvEA5H/cDPdvJ/SZJQLWWXWGrZ0= +github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= +github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= diff --git a/producer.go b/producer.go index a5bf44e..f1aecd9 100644 --- a/producer.go +++ b/producer.go @@ -316,10 +316,11 @@ type Headers struct { // ProduceOpts - configuration options for produce operations. type ProduceOpts struct { - Message any - AckWaitSec int - MsgHeaders Headers - AsyncProduce bool + Message any + AckWaitSec int + MsgHeaders Headers + AsyncProduce bool + ProducerPartitionKey string } // ProduceOpt - a function on the options for produce operations. 
@@ -380,11 +381,20 @@ func (opts *ProduceOpts) produce(p *Producer) error { var streamName string sn := getInternalName(p.stationName) + if len(p.conn.stationPartitions[sn].PartitionsList) == 1 { streamName = fmt.Sprintf("%v$%v", sn, p.conn.stationPartitions[sn].PartitionsList[0]) } else if len(p.conn.stationPartitions[sn].PartitionsList) > 1 { - partitionNumber := p.PartitionGenerator.Next() - streamName = fmt.Sprintf("%v$%v", sn, partitionNumber) + if opts.ProducerPartitionKey != "" { + partitionNumber, err := p.conn.GetPartitionFromKey(opts.ProducerPartitionKey, sn) + if err != nil { + return memphisError(fmt.Errorf("failed to get partition from key")) + } + streamName = fmt.Sprintf("%v$%v", sn, partitionNumber) + } else { + partitionNumber := p.PartitionGenerator.Next() + streamName = fmt.Sprintf("%v$%v", sn, partitionNumber) + } } else { streamName = sn } @@ -544,6 +554,14 @@ func AckWaitSec(ackWaitSec int) ProduceOpt { } } +// ProducerPartitionKey - set a partition key for a message +func ProducerPartitionKey(partitionKey string) ProduceOpt { + return func(opts *ProduceOpts) error { + opts.ProducerPartitionKey = partitionKey + return nil + } +} + // MsgHeaders - set headers to a message func MsgHeaders(hdrs Headers) ProduceOpt { return func(opts *ProduceOpts) error {