Skip to content

Commit

Permalink
Merge pull request #121 from memphisdev/partition_key
Browse files Browse the repository at this point in the history
producer / consumer  with partition key
  • Loading branch information
daniel-davidd authored Sep 10, 2023
2 parents b71935f + dd8828d commit afedbb6
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 28 deletions.
21 changes: 19 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,17 @@ p.Produce(
)
```


### Produce using partition key
The partition key will be used to produce messages to a spacific partition.

```go
p.Produce(
"<message in []byte or map[string]interface{}/[]byte or protoreflect.ProtoMessage or map[string]interface{}(schema validated station - protobuf)/struct with json tags or map[string]interface{} or interface{}(schema validated station - json schema) or []byte/string (schema validated station - graphql schema) or []byte or map[string]interface{} or struct with avro tags(schema validated station - avro schema)>",
memphis.ProducerPartitionKey(<string>)
)
```

### Destroying a Producer

```go
Expand Down Expand Up @@ -346,7 +357,9 @@ func handler(msgs []*memphis.Msg, err error, ctx context.Context) {
}
}

consumer.Consume(handler)
consumer.Consume(handler,
memphis.ConsumerPartitionKey(<string>) // use the partition key to consume from a spacific partition (if not specified consume in a Round Robin fashion)
)
```

### Fetch a single batch of messages
Expand All @@ -360,13 +373,17 @@ msgs, err := conn.FetchMessages("<station-name>", "<consumer-name>",
memphis.FetchConsumerErrorHandler(func(*Consumer, error){})
memphis.FetchStartConsumeFromSeq(<uint64>)// start consuming from a specific sequence. defaults to 1
memphis.FetchLastMessages(<int64>)// consume the last N messages, defaults to -1 (all messages in the station))
memphis.FetchPartitionKey(<string>)// use the partition key to consume from a spacific partition (if not specified consume in a Round Robin fashion)
```

### Fetch a single batch of messages after creating a consumer
`prefetch = true` will prefetch next batch of messages and save it in memory for future Fetch() request<br>
Note: Use a higher MaxAckTime as the messages will sit in a local cache for some time before processing
```go
msgs, err := consumer.Fetch(<batch-size> int, <prefetch> bool)
msgs, err := consumer.Fetch(<batch-size> int,
<prefetch> bool,
memphis.ConsumerPartitionKey(<string>) // use the partition key to consume from a spacific partition (if not specified consume in a Round Robin fashion)
)
```

### Acknowledging a Message
Expand Down
24 changes: 23 additions & 1 deletion connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@ import (

"github.com/gofrs/uuid"
"github.com/nats-io/nats.go"
"github.com/spaolacci/murmur3"
)

const (
sdkClientsUpdatesSubject = "$memphis_sdk_clients_updates"
maxBatchSize = 5000
memphisGlobalAccountName = "$memphis"
SEED = 1234
)

var stationUpdatesSubsLock sync.Mutex
Expand Down Expand Up @@ -94,6 +96,7 @@ type FetchOpts struct {
StartConsumeFromSequence uint64
LastMessages int64
Prefetch bool
FetchPartitionKey string
}

// getDefaultConsumerOptions - returns default configuration options for consumers.
Expand All @@ -109,6 +112,7 @@ func getDefaultFetchOptions() FetchOpts {
StartConsumeFromSequence: 1,
LastMessages: -1,
Prefetch: false,
FetchPartitionKey: "",
}
}

Expand Down Expand Up @@ -799,7 +803,7 @@ func (c *Conn) FetchMessages(stationName string, consumerName string, opts ...Fe
} else {
consumer = cons
}
msgs, err := consumer.Fetch(defaultOpts.BatchSize, defaultOpts.Prefetch)
msgs, err := consumer.Fetch(defaultOpts.BatchSize, defaultOpts.Prefetch, ConsumerPartitionKey(defaultOpts.FetchPartitionKey))
if err != nil {
return nil, err
}
Expand All @@ -814,6 +818,14 @@ func FetchConsumerGroup(cg string) FetchOpt {
}
}

// PartitionKey - partition key to consume from.
func FetchPartitionKey(partitionKey string) FetchOpt {
return func(opts *FetchOpts) error {
opts.FetchPartitionKey = partitionKey
return nil
}
}

// BatchSize - pull batch size.
func FetchBatchSize(batchSize int) FetchOpt {
return func(opts *FetchOpts) error {
Expand Down Expand Up @@ -883,3 +895,13 @@ func init() {
applicationId = appId.String()
}
}

func (c *Conn) GetPartitionFromKey(key string, stationName string) (int, error) {
mur3 := murmur3.New32WithSeed(SEED)
_, err := mur3.Write([]byte(key))
if err != nil {
return -1, err
}
PartitionIndex := int(mur3.Sum32()) % len(c.stationPartitions[stationName].PartitionsList)
return c.stationPartitions[stationName].PartitionsList[PartitionIndex], nil
}
96 changes: 77 additions & 19 deletions consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ func (opts *ConsumerOpts) createConsumer(c *Conn) (*Consumer, error) {
sn := getInternalName(consumer.stationName)

durable := getInternalName(consumer.ConsumerGroup)

if len(consumer.conn.stationPartitions[sn].PartitionsList) == 0 {
consumer.subscriptions = make(map[int]*nats.Subscription, 1)
subj := sn + ".final"
Expand Down Expand Up @@ -382,9 +383,11 @@ func (c *Consumer) pingConsumer() {
}(sub)
}
wg.Wait()
if generalErr != nil && (strings.Contains(generalErr.Error(), "consumer not found") || strings.Contains(generalErr.Error(), "stream not found")) {
c.subscriptionActive = false
c.callErrHandler(ConsumerErrStationUnreachable)
if generalErr != nil {
if strings.Contains(generalErr.Error(), "consumer not found") || strings.Contains(generalErr.Error(), "stream not found") {
c.subscriptionActive = false
c.callErrHandler(ConsumerErrStationUnreachable)
}
}
case <-c.pingQuit:
ticker.Stop()
Expand All @@ -401,11 +404,44 @@ func (c *Consumer) SetContext(ctx context.Context) {
// ConsumeHandler - handler for consumed messages
type ConsumeHandler func([]*Msg, error, context.Context)

// ConsumingOpts - configuration options for consuming messages
type ConsumingOpts struct {
ConsumerPartitionKey string
}

type ConsumingOpt func(*ConsumingOpts) error

// ConsumerPartitionKey - Partition key for the consumer to consume from
func ConsumerPartitionKey(ConsumerPartitionKey string) ConsumingOpt {
return func(opts *ConsumingOpts) error {
opts.ConsumerPartitionKey = ConsumerPartitionKey
return nil
}
}

func getDefaultConsumingOptions() ConsumingOpts {
return ConsumingOpts{
ConsumerPartitionKey: "",
}
}

// Consumer.Consume - start consuming messages according to the interval configured in the consumer object.
// When a batch is consumed the handlerFunc will be called.
func (c *Consumer) Consume(handlerFunc ConsumeHandler) error {
go func(c *Consumer) {
msgs, err := c.fetchSubscription()
func (c *Consumer) Consume(handlerFunc ConsumeHandler, opts ...ConsumingOpt) error {

defaultOpts := getDefaultConsumingOptions()

for _, opt := range opts {
if opt != nil {
if err := opt(&defaultOpts); err != nil {
return memphisError(err)
}
}
}

go func(c *Consumer, partitionKey string) {

msgs, err := c.fetchSubscription(partitionKey)
handlerFunc(msgs, memphisError(err), c.context)
c.dlsHandlerFunc = handlerFunc
ticker := time.NewTicker(c.PullInterval)
Expand All @@ -421,13 +457,13 @@ func (c *Consumer) Consume(handlerFunc ConsumeHandler) error {

select {
case <-ticker.C:
msgs, err := c.fetchSubscription()
msgs, err := c.fetchSubscription(partitionKey)
handlerFunc(msgs, memphisError(err), nil)
case <-c.consumeQuit:
return
}
}
}(c)
}(c, defaultOpts.ConsumerPartitionKey)
c.consumeActive = true
return nil
}
Expand All @@ -442,15 +478,26 @@ func (c *Consumer) StopConsume() {
c.consumeActive = false
}

func (c *Consumer) fetchSubscription() ([]*Msg, error) {
func (c *Consumer) fetchSubscription(partitionKey string) ([]*Msg, error) {

if !c.subscriptionActive {
return nil, memphisError(errors.New("station unreachable"))
}
wrappedMsgs := make([]*Msg, 0, c.BatchSize)
partitionNumber := 1

if len(c.subscriptions) > 1 {
partitionNumber = c.PartitionGenerator.Next()
if partitionKey != "" {
partitionFromKey, err := c.conn.GetPartitionFromKey(partitionKey, c.stationName)
if err != nil {
return nil, memphisError(err)
}
partitionNumber = partitionFromKey
} else {
partitionNumber = c.PartitionGenerator.Next()
}
}

msgs, err := c.subscriptions[partitionNumber].Fetch(c.BatchSize)
if err != nil && err != nats.ErrTimeout {
c.subscriptionActive = false
Expand All @@ -468,13 +515,14 @@ type fetchResult struct {
err error
}

func (c *Consumer) fetchSubscriprionWithTimeout() ([]*Msg, error) {
func (c *Consumer) fetchSubscriprionWithTimeout(partitionKey string) ([]*Msg, error) {
timeoutDuration := c.BatchMaxTimeToWait
out := make(chan fetchResult, 1)
go func() {
msgs, err := c.fetchSubscription()

go func(partitionKey string) {
msgs, err := c.fetchSubscription(partitionKey)
out <- fetchResult{msgs: msgs, err: memphisError(err)}
}()
}(partitionKey)
select {
case <-time.After(timeoutDuration):
return nil, memphisError(errors.New("fetch timed out"))
Expand All @@ -485,11 +533,21 @@ func (c *Consumer) fetchSubscriprionWithTimeout() ([]*Msg, error) {
}

// Fetch - immediately fetch a batch of messages.
func (c *Consumer) Fetch(batchSize int, prefetch bool) ([]*Msg, error) {
func (c *Consumer) Fetch(batchSize int, prefetch bool, opts ...ConsumingOpt) ([]*Msg, error) {
if batchSize > maxBatchSize {
return nil, memphisError(errors.New("Batch size can not be greater than " + strconv.Itoa(maxBatchSize)))
}

defaultOpts := getDefaultConsumingOptions()

for _, opt := range opts {
if opt != nil {
if err := opt(&defaultOpts); err != nil {
return nil, memphisError(err)
}
}
}

c.BatchSize = batchSize
var msgs []*Msg
if len(c.dlsMsgs) > 0 {
Expand Down Expand Up @@ -523,15 +581,15 @@ func (c *Consumer) Fetch(batchSize int, prefetch bool) ([]*Msg, error) {
}
c.conn.prefetchedMsgs.lock.Unlock()
if prefetch {
go c.prefetchMsgs()
go c.prefetchMsgs(defaultOpts.ConsumerPartitionKey)
}
if len(msgs) > 0 {
return msgs, nil
}
return c.fetchSubscriprionWithTimeout()
return c.fetchSubscriprionWithTimeout(defaultOpts.ConsumerPartitionKey)
}

func (c *Consumer) prefetchMsgs() {
func (c *Consumer) prefetchMsgs(partitionKey string) {
c.conn.prefetchedMsgs.lock.Lock()
defer c.conn.prefetchedMsgs.lock.Unlock()
lowerCaseStationName := getLowerCaseName(c.stationName)
Expand All @@ -541,7 +599,7 @@ func (c *Consumer) prefetchMsgs() {
if _, ok := c.conn.prefetchedMsgs.msgs[lowerCaseStationName][c.ConsumerGroup]; !ok {
c.conn.prefetchedMsgs.msgs[lowerCaseStationName][c.ConsumerGroup] = make([]*Msg, 0)
}
msgs, err := c.fetchSubscriprionWithTimeout()
msgs, err := c.fetchSubscriprionWithTimeout(partitionKey)
if err == nil {
c.conn.prefetchedMsgs.msgs[lowerCaseStationName][c.ConsumerGroup] = append(c.conn.prefetchedMsgs.msgs[lowerCaseStationName][c.ConsumerGroup], msgs...)
}
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ require (
github.com/nats-io/nats-server/v2 v2.9.5 // indirect
github.com/nats-io/nkeys v0.4.4 // indirect
github.com/nats-io/nuid v1.0.1 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
golang.org/x/crypto v0.6.0 // indirect
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
golang.org/x/sys v0.5.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ github.com/santhosh-tekuri/jsonschema/v5 v5.1.0 h1:wSUNu/w/7OQ0Y3NVnfTU5uxzXY4uM
github.com/santhosh-tekuri/jsonschema/v5 v5.1.0/go.mod h1:FKdcjfQW6rpZSnxxUvEA5H/cDPdvJ/SZJQLWWXWGrZ0=
github.com/santhosh-tekuri/jsonschema/v5 v5.3.0 h1:uIkTLo0AGRc8l7h5l9r+GcYi9qfVPt6lD4/bhmzfiKo=
github.com/santhosh-tekuri/jsonschema/v5 v5.3.0/go.mod h1:FKdcjfQW6rpZSnxxUvEA5H/cDPdvJ/SZJQLWWXWGrZ0=
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
Expand Down
30 changes: 24 additions & 6 deletions producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,10 +316,11 @@ type Headers struct {

// ProduceOpts - configuration options for produce operations.
type ProduceOpts struct {
Message any
AckWaitSec int
MsgHeaders Headers
AsyncProduce bool
Message any
AckWaitSec int
MsgHeaders Headers
AsyncProduce bool
ProducerPartitionKey string
}

// ProduceOpt - a function on the options for produce operations.
Expand Down Expand Up @@ -380,11 +381,20 @@ func (opts *ProduceOpts) produce(p *Producer) error {

var streamName string
sn := getInternalName(p.stationName)

if len(p.conn.stationPartitions[sn].PartitionsList) == 1 {
streamName = fmt.Sprintf("%v$%v", sn, p.conn.stationPartitions[sn].PartitionsList[0])
} else if len(p.conn.stationPartitions[sn].PartitionsList) > 1 {
partitionNumber := p.PartitionGenerator.Next()
streamName = fmt.Sprintf("%v$%v", sn, partitionNumber)
if opts.ProducerPartitionKey != "" {
partitionNumber, err := p.conn.GetPartitionFromKey(opts.ProducerPartitionKey, sn)
if err != nil {
return memphisError(fmt.Errorf("failed to get partition from key"))
}
streamName = fmt.Sprintf("%v$%v", sn, partitionNumber)
} else {
partitionNumber := p.PartitionGenerator.Next()
streamName = fmt.Sprintf("%v$%v", sn, partitionNumber)
}
} else {
streamName = sn
}
Expand Down Expand Up @@ -544,6 +554,14 @@ func AckWaitSec(ackWaitSec int) ProduceOpt {
}
}

// ProducerPartitionKey - set a partition key for a message
func ProducerPartitionKey(partitionKey string) ProduceOpt {
return func(opts *ProduceOpts) error {
opts.ProducerPartitionKey = partitionKey
return nil
}
}

// MsgHeaders - set headers to a message
func MsgHeaders(hdrs Headers) ProduceOpt {
return func(opts *ProduceOpts) error {
Expand Down

0 comments on commit afedbb6

Please sign in to comment.