Skip to content

Commit

Permalink
Merge pull request #1023 from bobrik/multiple-record-batches
Browse files Browse the repository at this point in the history
Support multiple record batches, closes #1022
  • Loading branch information
eapache authored Jan 22, 2018
2 parents 1abfd98 + 3be0b7c commit 0f4f8ca
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 43 deletions.
35 changes: 27 additions & 8 deletions consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,12 +570,12 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*Consu
return nil, block.Err
}

nRecs, err := block.Records.numRecords()
nRecs, err := block.numRecords()
if err != nil {
return nil, err
}
if nRecs == 0 {
partialTrailingMessage, err := block.Records.isPartial()
partialTrailingMessage, err := block.isPartial()
if err != nil {
return nil, err
}
Expand All @@ -601,14 +601,33 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*Consu
child.fetchSize = child.conf.Consumer.Fetch.Default
atomic.StoreInt64(&child.highWaterMarkOffset, block.HighWaterMarkOffset)

if control, err := block.Records.isControl(); err != nil || control {
return nil, err
}
messages := []*ConsumerMessage{}
for _, records := range block.RecordsSet {
if control, err := records.isControl(); err != nil || control {
continue
}

switch records.recordsType {
case legacyRecords:
messageSetMessages, err := child.parseMessages(records.msgSet)
if err != nil {
return nil, err
}

if block.Records.recordsType == legacyRecords {
return child.parseMessages(block.Records.msgSet)
messages = append(messages, messageSetMessages...)
case defaultRecords:
recordBatchMessages, err := child.parseRecords(records.recordBatch)
if err != nil {
return nil, err
}

messages = append(messages, recordBatchMessages...)
default:
return nil, fmt.Errorf("unknown records type: %v", records.recordsType)
}
}
return child.parseRecords(block.Records.recordBatch)

return messages, nil
}

// brokerConsumer
Expand Down
98 changes: 79 additions & 19 deletions fetch_response.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package sarama

import "time"
import (
"time"
)

type AbortedTransaction struct {
ProducerID int64
Expand Down Expand Up @@ -31,7 +33,9 @@ type FetchResponseBlock struct {
HighWaterMarkOffset int64
LastStableOffset int64
AbortedTransactions []*AbortedTransaction
Records Records
Records *Records // deprecated: use FetchResponseBlock.Records
RecordsSet []*Records
Partial bool
}

func (b *FetchResponseBlock) decode(pd packetDecoder, version int16) (err error) {
Expand Down Expand Up @@ -79,15 +83,69 @@ func (b *FetchResponseBlock) decode(pd packetDecoder, version int16) (err error)
if err != nil {
return err
}
if recordsSize > 0 {
if err = b.Records.decode(recordsDecoder); err != nil {

b.RecordsSet = []*Records{}

for recordsDecoder.remaining() > 0 {
records := &Records{}
if err := records.decode(recordsDecoder); err != nil {
// If we have at least one decoded records, this is not an error
if err == ErrInsufficientData {
if len(b.RecordsSet) == 0 {
b.Partial = true
}
break
}
return err
}

partial, err := records.isPartial()
if err != nil {
return err
}

// If we have at least one full records, we skip incomplete ones
if partial && len(b.RecordsSet) > 0 {
break
}

b.RecordsSet = append(b.RecordsSet, records)

if b.Records == nil {
b.Records = records
}
}

return nil
}

func (b *FetchResponseBlock) numRecords() (int, error) {
sum := 0

for _, records := range b.RecordsSet {
count, err := records.numRecords()
if err != nil {
return 0, err
}

sum += count
}

return sum, nil
}

func (b *FetchResponseBlock) isPartial() (bool, error) {
if b.Partial {
return true, nil
}

if len(b.RecordsSet) == 1 {
return b.RecordsSet[0].isPartial()
}

return false, nil
}

func (b *FetchResponseBlock) encode(pe packetEncoder, version int16) (err error) {
pe.putInt16(int16(b.Err))

Expand All @@ -107,9 +165,11 @@ func (b *FetchResponseBlock) encode(pe packetEncoder, version int16) (err error)
}

pe.push(&lengthField{})
err = b.Records.encode(pe)
if err != nil {
return err
for _, records := range b.RecordsSet {
err = records.encode(pe)
if err != nil {
return err
}
}
return pe.pop()
}
Expand Down Expand Up @@ -289,33 +349,33 @@ func (r *FetchResponse) AddMessage(topic string, partition int32, key, value Enc
kb, vb := encodeKV(key, value)
msg := &Message{Key: kb, Value: vb}
msgBlock := &MessageBlock{Msg: msg, Offset: offset}
set := frb.Records.msgSet
if set == nil {
set = &MessageSet{}
frb.Records = newLegacyRecords(set)
if len(frb.RecordsSet) == 0 {
records := newLegacyRecords(&MessageSet{})
frb.RecordsSet = []*Records{&records}
}
set := frb.RecordsSet[0].msgSet
set.Messages = append(set.Messages, msgBlock)
}

func (r *FetchResponse) AddRecord(topic string, partition int32, key, value Encoder, offset int64) {
frb := r.getOrCreateBlock(topic, partition)
kb, vb := encodeKV(key, value)
rec := &Record{Key: kb, Value: vb, OffsetDelta: offset}
batch := frb.Records.recordBatch
if batch == nil {
batch = &RecordBatch{Version: 2}
frb.Records = newDefaultRecords(batch)
if len(frb.RecordsSet) == 0 {
records := newDefaultRecords(&RecordBatch{Version: 2})
frb.RecordsSet = []*Records{&records}
}
batch := frb.RecordsSet[0].recordBatch
batch.addRecord(rec)
}

func (r *FetchResponse) SetLastOffsetDelta(topic string, partition int32, offset int32) {
frb := r.getOrCreateBlock(topic, partition)
batch := frb.Records.recordBatch
if batch == nil {
batch = &RecordBatch{Version: 2}
frb.Records = newDefaultRecords(batch)
if len(frb.RecordsSet) == 0 {
records := newDefaultRecords(&RecordBatch{Version: 2})
frb.RecordsSet = []*Records{&records}
}
batch := frb.RecordsSet[0].recordBatch
batch.LastOffsetDelta = offset
}

Expand Down
18 changes: 9 additions & 9 deletions fetch_response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,22 +117,22 @@ func TestOneMessageFetchResponse(t *testing.T) {
if block.HighWaterMarkOffset != 0x10101010 {
t.Error("Decoding didn't produce correct high water mark offset.")
}
partial, err := block.Records.isPartial()
partial, err := block.isPartial()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if partial {
t.Error("Decoding detected a partial trailing message where there wasn't one.")
}

n, err := block.Records.numRecords()
n, err := block.numRecords()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n != 1 {
t.Fatal("Decoding produced incorrect number of messages.")
}
msgBlock := block.Records.msgSet.Messages[0]
msgBlock := block.RecordsSet[0].msgSet.Messages[0]
if msgBlock.Offset != 0x550000 {
t.Error("Decoding produced incorrect message offset.")
}
Expand Down Expand Up @@ -170,22 +170,22 @@ func TestOneRecordFetchResponse(t *testing.T) {
if block.HighWaterMarkOffset != 0x10101010 {
t.Error("Decoding didn't produce correct high water mark offset.")
}
partial, err := block.Records.isPartial()
partial, err := block.isPartial()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if partial {
t.Error("Decoding detected a partial trailing record where there wasn't one.")
}

n, err := block.Records.numRecords()
n, err := block.numRecords()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n != 1 {
t.Fatal("Decoding produced incorrect number of records.")
}
rec := block.Records.recordBatch.Records[0]
rec := block.RecordsSet[0].recordBatch.Records[0]
if !bytes.Equal(rec.Key, []byte{0x01, 0x02, 0x03, 0x04}) {
t.Error("Decoding produced incorrect record key.")
}
Expand Down Expand Up @@ -216,22 +216,22 @@ func TestOneMessageFetchResponseV4(t *testing.T) {
if block.HighWaterMarkOffset != 0x10101010 {
t.Error("Decoding didn't produce correct high water mark offset.")
}
partial, err := block.Records.isPartial()
partial, err := block.isPartial()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if partial {
t.Error("Decoding detected a partial trailing record where there wasn't one.")
}

n, err := block.Records.numRecords()
n, err := block.numRecords()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n != 1 {
t.Fatal("Decoding produced incorrect number of records.")
}
msgBlock := block.Records.msgSet.Messages[0]
msgBlock := block.RecordsSet[0].msgSet.Messages[0]
if msgBlock.Offset != 0x550000 {
t.Error("Decoding produced incorrect message offset.")
}
Expand Down
9 changes: 9 additions & 0 deletions message_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ func (ms *MessageSet) decode(pd packetDecoder) (err error) {
ms.Messages = nil

for pd.remaining() > 0 {
magic, err := magicValue(pd)
if err != nil {
return err
}

if magic > 1 {
return nil
}

msb := new(MessageBlock)
err = msb.decode(pd)
switch err {
Expand Down
20 changes: 13 additions & 7 deletions records.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,12 @@ func (r *Records) encode(pe packetEncoder) error {
}
return r.recordBatch.encode(pe)
}

return fmt.Errorf("unknown records type: %v", r.recordsType)
}

func (r *Records) setTypeFromMagic(pd packetDecoder) error {
dec, err := pd.peek(magicOffset, magicLength)
if err != nil {
return err
}

magic, err := dec.getInt8()
magic, err := magicValue(pd)
if err != nil {
return err
}
Expand All @@ -80,13 +76,14 @@ func (r *Records) setTypeFromMagic(pd packetDecoder) error {
if magic < 2 {
r.recordsType = legacyRecords
}

return nil
}

func (r *Records) decode(pd packetDecoder) error {
if r.recordsType == unknownRecords {
if err := r.setTypeFromMagic(pd); err != nil {
return nil
return err
}
}

Expand Down Expand Up @@ -165,3 +162,12 @@ func (r *Records) isControl() (bool, error) {
}
return false, fmt.Errorf("unknown records type: %v", r.recordsType)
}

func magicValue(pd packetDecoder) (int8, error) {
dec, err := pd.peek(magicOffset, magicLength)
if err != nil {
return 0, err
}

return dec.getInt8()
}

0 comments on commit 0f4f8ca

Please sign in to comment.