Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

producer / consumer with partition key #121

Merged
merged 7 commits into from
Sep 10, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@ import (

"github.com/gofrs/uuid"
"github.com/nats-io/nats.go"
"github.com/spaolacci/murmur3"
)

const (
sdkClientsUpdatesSubject = "$memphis_sdk_clients_updates"
maxBatchSize = 5000
memphisGlobalAccountName = "$memphis"
SEED = 1234
)

var stationUpdatesSubsLock sync.Mutex
Expand Down Expand Up @@ -94,6 +96,7 @@ type FetchOpts struct {
StartConsumeFromSequence uint64
LastMessages int64
Prefetch bool
FetchPartitionKey string
}

// getDefaultConsumerOptions - returns default configuration options for consumers.
Expand All @@ -109,6 +112,7 @@ func getDefaultFetchOptions() FetchOpts {
StartConsumeFromSequence: 1,
LastMessages: -1,
Prefetch: false,
FetchPartitionKey: "",
}
}

Expand Down Expand Up @@ -798,7 +802,7 @@ func (c *Conn) FetchMessages(stationName string, consumerName string, opts ...Fe
} else {
consumer = cons
}
msgs, err := consumer.Fetch(defaultOpts.BatchSize, defaultOpts.Prefetch)
msgs, err := consumer.Fetch(defaultOpts.BatchSize, defaultOpts.Prefetch, ConsumerPartitionKey(defaultOpts.FetchPartitionKey))
if err != nil {
return nil, err
}
Expand All @@ -813,6 +817,14 @@ func FetchConsumerGroup(cg string) FetchOpt {
}
}

// PartitionKey - partition key to consume from.
func FetchPartitionKey(partitionKey string) FetchOpt {
return func(opts *FetchOpts) error {
opts.FetchPartitionKey = partitionKey
return nil
}
}

// BatchSize - pull batch size.
func FetchBatchSize(batchSize int) FetchOpt {
return func(opts *FetchOpts) error {
Expand Down Expand Up @@ -882,3 +894,13 @@ func init() {
applicationId = appId.String()
}
}

func (c *Conn) GetPartitionFromKey(key string, stationName string) (int, error) {
daniel-davidd marked this conversation as resolved.
Show resolved Hide resolved
mur3 := murmur3.New32WithSeed(SEED)
_, err := mur3.Write([]byte(key))
if err != nil {
return -1, err
}
PartitionIndex := int(mur3.Sum32()) % len(c.stationPartitions[stationName].PartitionsList)
return c.stationPartitions[stationName].PartitionsList[PartitionIndex], nil
}
94 changes: 75 additions & 19 deletions consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ func (opts *ConsumerOpts) createConsumer(c *Conn) (*Consumer, error) {
sn := getInternalName(consumer.stationName)

durable := getInternalName(consumer.ConsumerGroup)

if len(consumer.conn.stationPartitions[sn].PartitionsList) == 0 {
consumer.subscriptions = make(map[int]*nats.Subscription, 1)
subj := sn + ".final"
Expand Down Expand Up @@ -382,9 +383,11 @@ func (c *Consumer) pingConsumer() {
}(sub)
}
wg.Wait()
if strings.Contains(generalErr.Error(), "consumer not found") || strings.Contains(generalErr.Error(), "stream not found") {
c.subscriptionActive = false
c.callErrHandler(ConsumerErrStationUnreachable)
if generalErr != nil {
if strings.Contains(generalErr.Error(), "consumer not found") || strings.Contains(generalErr.Error(), "stream not found") {
c.subscriptionActive = false
c.callErrHandler(ConsumerErrStationUnreachable)
}
}
case <-c.pingQuit:
ticker.Stop()
Expand All @@ -401,11 +404,42 @@ func (c *Consumer) SetContext(ctx context.Context) {
// ConsumeHandler - handler for consumed messages
type ConsumeHandler func([]*Msg, error, context.Context)

type ConsumingOpts struct {
daniel-davidd marked this conversation as resolved.
Show resolved Hide resolved
ConsumerPartitionKey string
}

type ConsumingOpt func(*ConsumingOpts) error

func ConsumerPartitionKey(ConsumerPartitionKey string) ConsumingOpt {
return func(opts *ConsumingOpts) error {
opts.ConsumerPartitionKey = ConsumerPartitionKey
return nil
}
}

func getDefaultConsumingOptions() ConsumingOpts {
return ConsumingOpts{
ConsumerPartitionKey: "",
}
}

// Consumer.Consume - start consuming messages according to the interval configured in the consumer object.
// When a batch is consumed the handlerFunc will be called.
func (c *Consumer) Consume(handlerFunc ConsumeHandler) error {
go func(c *Consumer) {
msgs, err := c.fetchSubscription()
func (c *Consumer) Consume(handlerFunc ConsumeHandler, opts ...ConsumingOpt) error {

defaultOpts := getDefaultConsumingOptions()

for _, opt := range opts {
if opt != nil {
if err := opt(&defaultOpts); err != nil {
return memphisError(err)
}
}
}

go func(c *Consumer, partitionKey string) {

msgs, err := c.fetchSubscription(partitionKey)
handlerFunc(msgs, memphisError(err), c.context)
c.dlsHandlerFunc = handlerFunc
ticker := time.NewTicker(c.PullInterval)
Expand All @@ -421,13 +455,13 @@ func (c *Consumer) Consume(handlerFunc ConsumeHandler) error {

select {
case <-ticker.C:
msgs, err := c.fetchSubscription()
msgs, err := c.fetchSubscription(partitionKey)
handlerFunc(msgs, memphisError(err), nil)
case <-c.consumeQuit:
return
}
}
}(c)
}(c, defaultOpts.ConsumerPartitionKey)
c.consumeActive = true
return nil
}
Expand All @@ -442,15 +476,26 @@ func (c *Consumer) StopConsume() {
c.consumeActive = false
}

func (c *Consumer) fetchSubscription() ([]*Msg, error) {
func (c *Consumer) fetchSubscription(partitionKey string) ([]*Msg, error) {

if !c.subscriptionActive {
return nil, memphisError(errors.New("station unreachable"))
}
wrappedMsgs := make([]*Msg, 0, c.BatchSize)
partitionNumber := 1

if len(c.subscriptions) > 1 {
partitionNumber = c.PartitionGenerator.Next()
if partitionKey != "" {
partitionFromKey, err := c.conn.GetPartitionFromKey(partitionKey, c.stationName)
if err != nil {
return nil, memphisError(err)
}
partitionNumber = partitionFromKey
} else {
partitionNumber = c.PartitionGenerator.Next()
}
}

msgs, err := c.subscriptions[partitionNumber].Fetch(c.BatchSize)
if err != nil && err != nats.ErrTimeout {
c.subscriptionActive = false
Expand All @@ -468,13 +513,14 @@ type fetchResult struct {
err error
}

func (c *Consumer) fetchSubscriprionWithTimeout() ([]*Msg, error) {
func (c *Consumer) fetchSubscriprionWithTimeout(partitionKey string) ([]*Msg, error) {
timeoutDuration := c.BatchMaxTimeToWait
out := make(chan fetchResult, 1)
go func() {
msgs, err := c.fetchSubscription()

go func(partitionKey string) {
msgs, err := c.fetchSubscription(partitionKey)
out <- fetchResult{msgs: msgs, err: memphisError(err)}
}()
}(partitionKey)
select {
case <-time.After(timeoutDuration):
return nil, memphisError(errors.New("fetch timed out"))
Expand All @@ -485,11 +531,21 @@ func (c *Consumer) fetchSubscriprionWithTimeout() ([]*Msg, error) {
}

// Fetch - immediately fetch a batch of messages.
func (c *Consumer) Fetch(batchSize int, prefetch bool) ([]*Msg, error) {
func (c *Consumer) Fetch(batchSize int, prefetch bool, opts ...ConsumingOpt) ([]*Msg, error) {
if batchSize > maxBatchSize {
return nil, memphisError(errors.New("Batch size can not be greater than " + strconv.Itoa(maxBatchSize)))
}

defaultOpts := getDefaultConsumingOptions()

for _, opt := range opts {
if opt != nil {
if err := opt(&defaultOpts); err != nil {
return nil, memphisError(err)
}
}
}

c.BatchSize = batchSize
var msgs []*Msg
if len(c.dlsMsgs) > 0 {
Expand Down Expand Up @@ -523,15 +579,15 @@ func (c *Consumer) Fetch(batchSize int, prefetch bool) ([]*Msg, error) {
}
c.conn.prefetchedMsgs.lock.Unlock()
if prefetch {
go c.prefetchMsgs()
go c.prefetchMsgs(defaultOpts.ConsumerPartitionKey)
}
if len(msgs) > 0 {
return msgs, nil
}
return c.fetchSubscriprionWithTimeout()
return c.fetchSubscriprionWithTimeout(defaultOpts.ConsumerPartitionKey)
}

func (c *Consumer) prefetchMsgs() {
func (c *Consumer) prefetchMsgs(partitionKey string) {
c.conn.prefetchedMsgs.lock.Lock()
defer c.conn.prefetchedMsgs.lock.Unlock()
lowerCaseStationName := getLowerCaseName(c.stationName)
Expand All @@ -541,7 +597,7 @@ func (c *Consumer) prefetchMsgs() {
if _, ok := c.conn.prefetchedMsgs.msgs[lowerCaseStationName][c.ConsumerGroup]; !ok {
c.conn.prefetchedMsgs.msgs[lowerCaseStationName][c.ConsumerGroup] = make([]*Msg, 0)
}
msgs, err := c.fetchSubscriprionWithTimeout()
msgs, err := c.fetchSubscriprionWithTimeout(partitionKey)
if err == nil {
c.conn.prefetchedMsgs.msgs[lowerCaseStationName][c.ConsumerGroup] = append(c.conn.prefetchedMsgs.msgs[lowerCaseStationName][c.ConsumerGroup], msgs...)
}
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ require (
github.com/nats-io/nats-server/v2 v2.9.5 // indirect
github.com/nats-io/nkeys v0.4.4 // indirect
github.com/nats-io/nuid v1.0.1 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
golang.org/x/crypto v0.6.0 // indirect
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
golang.org/x/sys v0.5.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ github.com/santhosh-tekuri/jsonschema/v5 v5.1.0 h1:wSUNu/w/7OQ0Y3NVnfTU5uxzXY4uM
github.com/santhosh-tekuri/jsonschema/v5 v5.1.0/go.mod h1:FKdcjfQW6rpZSnxxUvEA5H/cDPdvJ/SZJQLWWXWGrZ0=
github.com/santhosh-tekuri/jsonschema/v5 v5.3.0 h1:uIkTLo0AGRc8l7h5l9r+GcYi9qfVPt6lD4/bhmzfiKo=
github.com/santhosh-tekuri/jsonschema/v5 v5.3.0/go.mod h1:FKdcjfQW6rpZSnxxUvEA5H/cDPdvJ/SZJQLWWXWGrZ0=
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
Expand Down
30 changes: 24 additions & 6 deletions producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,10 +316,12 @@ type Headers struct {

// ProduceOpts - configuration options for produce operations.
type ProduceOpts struct {
Message any
AckWaitSec int
MsgHeaders Headers
AsyncProduce bool
Message any
AckWaitSec int
MsgHeaders Headers
AsyncProduce bool
ProducerPartitionKey string
PartitionIndex []int //ask asulin
}

// ProduceOpt - a function on the options for produce operations.
Expand Down Expand Up @@ -380,11 +382,20 @@ func (opts *ProduceOpts) produce(p *Producer) error {

var streamName string
sn := getInternalName(p.stationName)

if len(p.conn.stationPartitions[sn].PartitionsList) == 1 {
streamName = fmt.Sprintf("%v$%v", sn, p.conn.stationPartitions[sn].PartitionsList[0])
} else if len(p.conn.stationPartitions[sn].PartitionsList) > 1 {
partitionNumber := p.PartitionGenerator.Next()
streamName = fmt.Sprintf("%v$%v", sn, partitionNumber)
if opts.ProducerPartitionKey != "" {
partitionNumber, err := p.conn.GetPartitionFromKey(opts.ProducerPartitionKey, sn)
if err != nil {
return memphisError(fmt.Errorf("failed to get partition from key"))
}
streamName = fmt.Sprintf("%v$%v", sn, partitionNumber)
} else {
partitionNumber := p.PartitionGenerator.Next()
streamName = fmt.Sprintf("%v$%v", sn, partitionNumber)
}
} else {
streamName = sn
}
Expand Down Expand Up @@ -544,6 +555,13 @@ func AckWaitSec(ackWaitSec int) ProduceOpt {
}
}

func ProducerPartitionKey(partitionKey string) ProduceOpt {
daniel-davidd marked this conversation as resolved.
Show resolved Hide resolved
return func(opts *ProduceOpts) error {
opts.ProducerPartitionKey = partitionKey
return nil
}
}

// MsgHeaders - set headers to a message
func MsgHeaders(hdrs Headers) ProduceOpt {
return func(opts *ProduceOpts) error {
Expand Down