diff --git a/cmd/kafka-consumer/main.go b/cmd/kafka-consumer/main.go index 80c0b4eb3ea..7dff0f9662d 100644 --- a/cmd/kafka-consumer/main.go +++ b/cmd/kafka-consumer/main.go @@ -512,21 +512,20 @@ func (c *Consumer) ConsumeClaim(session sarama.ConsumerGroupSession, claim saram panic("sink should initialized") } + var decoder codec.RowEventDecoder + switch c.protocol { + case config.ProtocolOpen, config.ProtocolDefault: + decoder = open.NewBatchDecoder() + case config.ProtocolCanalJSON: + decoder = canal.NewBatchDecoder(c.enableTiDBExtension, "") + default: + log.Panic("Protocol not supported", zap.Any("Protocol", c.protocol)) + } + eventGroups := make(map[int64]*eventsGroup) for message := range claim.Messages() { - var ( - decoder codec.RowEventDecoder - err error - ) - switch c.protocol { - case config.ProtocolOpen, config.ProtocolDefault: - decoder, err = open.NewBatchDecoder(message.Key, message.Value) - case config.ProtocolCanalJSON: - decoder = canal.NewBatchDecoder(message.Value, c.enableTiDBExtension, "") - default: - log.Panic("Protocol not supported", zap.Any("Protocol", c.protocol)) - } - if err != nil { + if err := decoder.AddKeyValue(message.Key, message.Value); err != nil { + log.Error("add key value to the decoder failed", zap.Error(err)) + return errors.Trace(err) } diff --git a/cmd/storage-consumer/main.go b/cmd/storage-consumer/main.go index 2a2e118ee68..017fe4db627 100644 --- a/cmd/storage-consumer/main.go +++ b/cmd/storage-consumer/main.go @@ -444,7 +444,11 @@ func (c *consumer) emitDMLEvents( case config.ProtocolCanalJSON: // Always enable tidb extension for canal-json protocol // because we need to get the commit ts from the extension field. 
- decoder = canal.NewBatchDecoder(content, true, c.codecCfg.Terminator) + decoder = canal.NewBatchDecoder(true, c.codecCfg.Terminator) + err := decoder.AddKeyValue(nil, content) + if err != nil { + return errors.Trace(err) + } } cnt := 0 diff --git a/pkg/sink/codec/avro/avro.go b/pkg/sink/codec/avro/avro.go index 966b7e3f041..60aec3d016d 100644 --- a/pkg/sink/codec/avro/avro.go +++ b/pkg/sink/codec/avro/avro.go @@ -51,11 +51,11 @@ type BatchEncoder struct { // Options is used to initialize the encoder, control the encoding behavior. type Options struct { - enableTiDBExtension bool - enableRowChecksum bool + EnableTiDBExtension bool + EnableRowChecksum bool - decimalHandlingMode string - bigintUnsignedHandlingMode string + DecimalHandlingMode string + BigintUnsignedHandlingMode string } type avroEncodeInput struct { @@ -194,8 +194,8 @@ func (a *BatchEncoder) avroEncode( colInfos: e.ColInfos, } - enableTiDBExtension = a.enableTiDBExtension - enableRowLevelChecksum = a.enableRowChecksum + enableTiDBExtension = a.EnableTiDBExtension + enableRowLevelChecksum = a.EnableRowChecksum schemaManager = a.valueSchemaManager if e.IsInsert() { operation = insertOperation @@ -220,8 +220,8 @@ func (a *BatchEncoder) avroEncode( input, enableTiDBExtension, enableRowLevelChecksum, - a.decimalHandlingMode, - a.bigintUnsignedHandlingMode, + a.DecimalHandlingMode, + a.BigintUnsignedHandlingMode, ) if err != nil { log.Error("AvroEventBatchEncoder: generating schema failed", zap.Error(err)) @@ -245,8 +245,8 @@ func (a *BatchEncoder) avroEncode( e.CommitTs, operation, enableTiDBExtension, - a.decimalHandlingMode, - a.bigintUnsignedHandlingMode, + a.DecimalHandlingMode, + a.BigintUnsignedHandlingMode, ) if err != nil { log.Error("AvroEventBatchEncoder: converting to native failed", zap.Error(err)) @@ -924,10 +924,10 @@ func (b *batchEncoderBuilder) Build() codec.RowEventEncoder { valueSchemaManager: b.valueSchemaManager, result: make([]*common.Message, 0, 1), Options: &Options{ - 
enableTiDBExtension: b.config.EnableTiDBExtension, - enableRowChecksum: b.config.EnableRowChecksum, - decimalHandlingMode: b.config.AvroDecimalHandlingMode, - bigintUnsignedHandlingMode: b.config.AvroBigintUnsignedHandlingMode, + EnableTiDBExtension: b.config.EnableTiDBExtension, + EnableRowChecksum: b.config.EnableRowChecksum, + DecimalHandlingMode: b.config.AvroDecimalHandlingMode, + BigintUnsignedHandlingMode: b.config.AvroBigintUnsignedHandlingMode, }, } diff --git a/pkg/sink/codec/avro/avro_test.go b/pkg/sink/codec/avro/avro_test.go index 56b8609fc31..e83e1b0d385 100644 --- a/pkg/sink/codec/avro/avro_test.go +++ b/pkg/sink/codec/avro/avro_test.go @@ -65,9 +65,9 @@ func setupEncoderAndSchemaRegistry( keySchemaManager: keyManager, result: make([]*common.Message, 0, 1), Options: &Options{ - enableTiDBExtension: enableTiDBExtension, - decimalHandlingMode: decimalHandlingMode, - bigintUnsignedHandlingMode: bigintUnsignedHandlingMode, + EnableTiDBExtension: enableTiDBExtension, + DecimalHandlingMode: decimalHandlingMode, + BigintUnsignedHandlingMode: bigintUnsignedHandlingMode, }, }, nil } diff --git a/pkg/sink/codec/builder/codec_test.go b/pkg/sink/codec/builder/codec_test.go index 914aa7370d1..114234eaad8 100644 --- a/pkg/sink/codec/builder/codec_test.go +++ b/pkg/sink/codec/builder/codec_test.go @@ -277,9 +277,9 @@ func BenchmarkProtobuf2Encoding(b *testing.B) { func BenchmarkCraftDecoding(b *testing.B) { allocator := craft.NewSliceAllocator(128) for i := 0; i < b.N; i++ { + decoder := craft.NewBatchDecoderWithAllocator(allocator) for _, message := range codecCraftEncodedRowChanges { - if decoder, err := craft.NewBatchDecoderWithAllocator( - message.Value, allocator); err != nil { + if err := decoder.AddKeyValue(message.Key, message.Value); err != nil { panic(err) } else { for { @@ -299,7 +299,8 @@ func BenchmarkCraftDecoding(b *testing.B) { func BenchmarkJsonDecoding(b *testing.B) { for i := 0; i < b.N; i++ { for _, message := range codecJSONEncodedRowChanges { 
- if decoder, err := open.NewBatchDecoder(message.Key, message.Value); err != nil { + decoder := open.NewBatchDecoder() + if err := decoder.AddKeyValue(message.Key, message.Value); err != nil { panic(err) } else { for { diff --git a/pkg/sink/codec/canal/canal_json_decoder.go b/pkg/sink/codec/canal/canal_json_decoder.go index 60697914c11..31b2ca17b46 100644 --- a/pkg/sink/codec/canal/canal_json_decoder.go +++ b/pkg/sink/codec/canal/canal_json_decoder.go @@ -33,18 +33,22 @@ type batchDecoder struct { } // NewBatchDecoder return a decoder for canal-json -func NewBatchDecoder(data []byte, +func NewBatchDecoder( enableTiDBExtension bool, terminator string, ) codec.RowEventDecoder { return &batchDecoder{ - data: data, - msg: nil, enableTiDBExtension: enableTiDBExtension, terminator: terminator, } } +// AddKeyValue implements the RowEventDecoder interface +func (b *batchDecoder) AddKeyValue(_, value []byte) error { + b.data = value + return nil +} + // HasNext implements the RowEventDecoder interface func (b *batchDecoder) HasNext() (model.MessageType, bool, error) { var ( diff --git a/pkg/sink/codec/canal/canal_json_decoder_test.go b/pkg/sink/codec/canal/canal_json_decoder_test.go index c1783f4b1ad..67bdd55cf56 100644 --- a/pkg/sink/codec/canal/canal_json_decoder_test.go +++ b/pkg/sink/codec/canal/canal_json_decoder_test.go @@ -42,7 +42,9 @@ func TestNewCanalJSONBatchDecoder4RowMessage(t *testing.T) { msg := messages[0] for _, decodeEnable := range []bool{false, true} { - decoder := NewBatchDecoder(msg.Value, decodeEnable, "") + decoder := NewBatchDecoder(decodeEnable, "") + err := decoder.AddKeyValue(msg.Key, msg.Value) + require.NoError(t, err) ty, hasNext, err := decoder.HasNext() require.Nil(t, err) @@ -95,7 +97,9 @@ func TestNewCanalJSONBatchDecoder4DDLMessage(t *testing.T) { require.NotNil(t, result) for _, decodeEnable := range []bool{false, true} { - decoder := NewBatchDecoder(result.Value, decodeEnable, "") + decoder := NewBatchDecoder(decodeEnable, "") + err := 
decoder.AddKeyValue(nil, result.Value) + require.NoError(t, err) ty, hasNext, err := decoder.HasNext() require.Nil(t, err) @@ -130,7 +134,10 @@ func TestCanalJSONBatchDecoderWithTerminator(t *testing.T) { encodedValue := `{"id":0,"database":"test","table":"employee","pkNames":["id"],"isDdl":false,"type":"INSERT","es":1668067205238,"ts":1668067206650,"sql":"","sqlType":{"FirstName":12,"HireDate":91,"LastName":12,"OfficeLocation":12,"id":4},"mysqlType":{"FirstName":"varchar","HireDate":"date","LastName":"varchar","OfficeLocation":"varchar","id":"int"},"data":[{"FirstName":"Bob","HireDate":"2014-06-04","LastName":"Smith","OfficeLocation":"New York","id":"101"}],"old":null} {"id":0,"database":"test","table":"employee","pkNames":["id"],"isDdl":false,"type":"UPDATE","es":1668067229137,"ts":1668067230720,"sql":"","sqlType":{"FirstName":12,"HireDate":91,"LastName":12,"OfficeLocation":12,"id":4},"mysqlType":{"FirstName":"varchar","HireDate":"date","LastName":"varchar","OfficeLocation":"varchar","id":"int"},"data":[{"FirstName":"Bob","HireDate":"2015-10-08","LastName":"Smith","OfficeLocation":"Los Angeles","id":"101"}],"old":[{"FirstName":"Bob","HireDate":"2014-06-04","LastName":"Smith","OfficeLocation":"New York","id":"101"}]} {"id":0,"database":"test","table":"employee","pkNames":["id"],"isDdl":false,"type":"DELETE","es":1668067230388,"ts":1668067231725,"sql":"","sqlType":{"FirstName":12,"HireDate":91,"LastName":12,"OfficeLocation":12,"id":4},"mysqlType":{"FirstName":"varchar","HireDate":"date","LastName":"varchar","OfficeLocation":"varchar","id":"int"},"data":[{"FirstName":"Bob","HireDate":"2015-10-08","LastName":"Smith","OfficeLocation":"Los Angeles","id":"101"}],"old":null}` - decoder := NewBatchDecoder([]byte(encodedValue), false, "\n") + decoder := NewBatchDecoder(false, "\n") + err := decoder.AddKeyValue(nil, []byte(encodedValue)) + require.NoError(t, err) + cnt := 0 for { tp, hasNext, err := decoder.HasNext() diff --git 
a/pkg/sink/codec/canal/canal_json_row_event_encoder_test.go b/pkg/sink/codec/canal/canal_json_row_event_encoder_test.go index 96c1de19e2d..ee197839f35 100644 --- a/pkg/sink/codec/canal/canal_json_row_event_encoder_test.go +++ b/pkg/sink/codec/canal/canal_json_row_event_encoder_test.go @@ -218,7 +218,10 @@ func TestEncodeCheckpointEvent(t *testing.T) { } require.NotNil(t, msg) - decoder := NewBatchDecoder(msg.Value, enable, "") + decoder := NewBatchDecoder(enable, "") + + err = decoder.AddKeyValue(msg.Key, msg.Value) + require.NoError(t, err) ty, hasNext, err := decoder.HasNext() require.Nil(t, err) diff --git a/pkg/sink/codec/craft/craft_decoder.go b/pkg/sink/codec/craft/craft_decoder.go index fc351c31487..ab1cd1763b5 100644 --- a/pkg/sink/codec/craft/craft_decoder.go +++ b/pkg/sink/codec/craft/craft_decoder.go @@ -118,27 +118,33 @@ func (b *batchDecoder) NextDDLEvent() (*model.DDLEvent, error) { return event, nil } -// newBatchDecoder creates a new batchDecoder. func newBatchDecoder(bits []byte) (codec.RowEventDecoder, error) { - return NewBatchDecoderWithAllocator(bits, NewSliceAllocator(64)) + decoder := NewBatchDecoderWithAllocator(NewSliceAllocator(64)) + err := decoder.AddKeyValue(nil, bits) + return decoder, err } // NewBatchDecoderWithAllocator creates a new batchDecoder with given allocator. 
func NewBatchDecoderWithAllocator( - bits []byte, allocator *SliceAllocator, -) (codec.RowEventDecoder, error) { - decoder, err := NewMessageDecoder(bits, allocator) + allocator *SliceAllocator, +) codec.RowEventDecoder { + return &batchDecoder{ + allocator: allocator, + } +} + +// AddKeyValue implements the RowEventDecoder interface +func (b *batchDecoder) AddKeyValue(_, value []byte) error { + decoder, err := NewMessageDecoder(value, b.allocator) if err != nil { - return nil, errors.Trace(err) + return errors.Trace(err) } headers, err := decoder.Headers() if err != nil { - return nil, errors.Trace(err) + return errors.Trace(err) } + b.decoder = decoder + b.headers = headers - return &batchDecoder{ - headers: headers, - decoder: decoder, - allocator: allocator, - }, nil + return nil } diff --git a/pkg/sink/codec/csv/csv_decoder.go b/pkg/sink/codec/csv/csv_decoder.go index 3c9bcb3d038..84e1f29d0fa 100644 --- a/pkg/sink/codec/csv/csv_decoder.go +++ b/pkg/sink/codec/csv/csv_decoder.go @@ -74,6 +74,11 @@ func NewBatchDecoder(ctx context.Context, }, nil } +// AddKeyValue implements the RowEventDecoder interface. +func (b *batchDecoder) AddKeyValue(_, _ []byte) error { + return nil +} + // HasNext implements the RowEventDecoder interface. func (b *batchDecoder) HasNext() (model.MessageType, bool, error) { err := b.parser.ReadRow() diff --git a/pkg/sink/codec/decoder.go b/pkg/sink/codec/decoder.go index 963cbee4f2e..acfdb7ee216 100644 --- a/pkg/sink/codec/decoder.go +++ b/pkg/sink/codec/decoder.go @@ -18,6 +18,11 @@ import "github.com/pingcap/tiflow/cdc/model" // RowEventDecoder is an abstraction for events decoder // this interface is only for testing now type RowEventDecoder interface { + // AddKeyValue add the received key and values to the decoder, + // should be called before `HasNext` + // decoder decode the key and value into the event format. + AddKeyValue(key, value []byte) error + // HasNext returns // 1. the type of the next event // 2. 
a bool if the next event is exist diff --git a/pkg/sink/codec/internal/batch_tester.go b/pkg/sink/codec/internal/batch_tester.go index 54d2e852230..1ea80cadaad 100644 --- a/pkg/sink/codec/internal/batch_tester.go +++ b/pkg/sink/codec/internal/batch_tester.go @@ -10,6 +10,7 @@ // distributed under the License is distributed on an "AS IS" BASIS, // See the License for the specific language governing permissions and // limitations under the License. + package internal import ( @@ -288,8 +289,10 @@ func (s *BatchTester) TestBatchCodec( res := encoder.Build() require.Len(t, res, 1) require.Equal(t, len(cs), res[0].GetRowsCount()) + decoder, err := newDecoder(res[0].Key, res[0].Value) - require.Nil(t, err) + require.NoError(t, err) + checkRowDecoder(decoder, cs) } } @@ -299,8 +302,10 @@ func (s *BatchTester) TestBatchCodec( msg, err := encoder.EncodeDDLEvent(ddl) require.Nil(t, err) require.NotNil(t, msg) + decoder, err := newDecoder(msg.Key, msg.Value) - require.Nil(t, err) + require.NoError(t, err) + checkDDLDecoder(decoder, cs[i:i+1]) } @@ -312,8 +317,10 @@ func (s *BatchTester) TestBatchCodec( msg, err := encoder.EncodeCheckpointEvent(ts) require.Nil(t, err) require.NotNil(t, msg) + decoder, err := newDecoder(msg.Key, msg.Value) - require.Nil(t, err) + require.NoError(t, err) + checkTSDecoder(decoder, cs[i:i+1]) } } diff --git a/pkg/sink/codec/open/open_protocol_decoder.go b/pkg/sink/codec/open/open_protocol_decoder.go index 34f82adaf23..0fc65a5ccae 100644 --- a/pkg/sink/codec/open/open_protocol_decoder.go +++ b/pkg/sink/codec/open/open_protocol_decoder.go @@ -224,20 +224,44 @@ func (b *BatchDecoder) decodeNextKey() error { } // NewBatchDecoder creates a new BatchDecoder. 
-func NewBatchDecoder(key []byte, value []byte) (codec.RowEventDecoder, error) { +func NewBatchDecoder() codec.RowEventDecoder { + return &BatchDecoder{} + +} + +// AddKeyValue implements the RowEventDecoder interface +func (b *BatchDecoder) AddKeyValue(key, value []byte) error { + if len(b.keyBytes) != 0 || len(b.valueBytes) != 0 { + return cerror.ErrOpenProtocolCodecInvalidData. + GenWithStack("decoder key and value not nil") + } + version := binary.BigEndian.Uint64(key[:8]) + key = key[8:] + if version != codec.BatchVersion1 { + return cerror.ErrOpenProtocolCodecInvalidData. + GenWithStack("unexpected key format version") + } + + b.keyBytes = key + b.valueBytes = value + + return nil + +} + +// AddKeyValue implements the RowEventDecoder interface +func (b *BatchMixedDecoder) AddKeyValue(key, value []byte) error { + if len(b.mixedBytes) != 0 { + return cerror.ErrOpenProtocolCodecInvalidData. + GenWithStack("decoder key and value not nil") + } + version := binary.BigEndian.Uint64(key[:8]) + key = key[8:] + if version != codec.BatchVersion1 { - return nil, cerror.ErrOpenProtocolCodecInvalidData.GenWithStack("unexpected key format version") - } - // if only decode one byte slice, we choose MixedDecoder - if len(key) > 0 && len(value) == 0 { - return &BatchMixedDecoder{ - mixedBytes: key, - }, nil - } - return &BatchDecoder{ - keyBytes: key, - valueBytes: value, - }, nil + return cerror.ErrOpenProtocolCodecInvalidData. 
+ GenWithStack("unexpected key format version") + } + + b.mixedBytes = key + return nil } diff --git a/pkg/sink/codec/open/open_protocol_encoder_test.go b/pkg/sink/codec/open/open_protocol_encoder_test.go index 809da7a633a..61bdcf32191 100644 --- a/pkg/sink/codec/open/open_protocol_encoder_test.go +++ b/pkg/sink/codec/open/open_protocol_encoder_test.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tiflow/cdc/model" "github.com/pingcap/tiflow/pkg/config" + "github.com/pingcap/tiflow/pkg/sink/codec" "github.com/pingcap/tiflow/pkg/sink/codec/common" "github.com/pingcap/tiflow/pkg/sink/codec/internal" "github.com/stretchr/testify/require" @@ -99,10 +100,11 @@ func TestMaxBatchSize(t *testing.T) { } messages := encoder.Build() + decoder := NewBatchDecoder() sum := 0 for _, msg := range messages { - decoder, err := NewBatchDecoder(msg.Key, msg.Value) - require.Nil(t, err) + err := decoder.AddKeyValue(msg.Key, msg.Value) + require.NoError(t, err) count := 0 for { v, hasNext, err := decoder.HasNext() @@ -206,5 +208,10 @@ func TestOpenProtocolBatchCodec(t *testing.T) { config := common.NewConfig(config.ProtocolOpen).WithMaxMessageBytes(8192) config.MaxBatchSize = 64 tester := internal.NewDefaultBatchTester() - tester.TestBatchCodec(t, NewBatchEncoderBuilder(config), NewBatchDecoder) + tester.TestBatchCodec(t, NewBatchEncoderBuilder(config), + func(key []byte, value []byte) (codec.RowEventDecoder, error) { + decoder := NewBatchDecoder() + err := decoder.AddKeyValue(key, value) + return decoder, err + }) }