diff --git a/cdc/sink/codec/avro.go b/cdc/sink/codec/avro.go index fa6f004ef26..bc6c80c8f7f 100644 --- a/cdc/sink/codec/avro.go +++ b/cdc/sink/codec/avro.go @@ -24,6 +24,7 @@ import ( "strconv" "time" + "github.com/linkedin/goavro/v2" "github.com/pingcap/errors" "github.com/pingcap/log" "github.com/pingcap/tidb/parser/mysql" @@ -227,11 +228,11 @@ func ColumnInfoToAvroSchema(name string, columnInfo []*model.Column) (string, er } field := make(map[string]interface{}) field["name"] = col.Name - if col.Flag.IsHandleKey() { - field["type"] = avroType - } else { + if col.Flag.IsNullable() { field["type"] = []interface{}{"null", avroType} field["default"] = nil + } else { + field["type"] = avroType } top.Fields = append(top.Fields, field) @@ -256,13 +257,12 @@ func rowToAvroNativeData(cols []*model.Column, colInfos []rowcodec.ColInfo, tz * return nil, err } - if col.Flag.IsHandleKey() { + // https://pkg.go.dev/github.com/linkedin/goavro/v2#Union + if col.Flag.IsNullable() { + ret[col.Name] = goavro.Union(str, data) + } else { ret[col.Name] = data - continue } - union := make(map[string]interface{}, 1) - union[str] = data - ret[col.Name] = union } return ret, nil } diff --git a/cdc/sink/codec/avro_test.go b/cdc/sink/codec/avro_test.go index 5e5112553ae..159e3273b04 100644 --- a/cdc/sink/codec/avro_test.go +++ b/cdc/sink/codec/avro_test.go @@ -15,6 +15,7 @@ package codec import ( "context" + "encoding/json" "time" "github.com/linkedin/goavro/v2" @@ -157,6 +158,80 @@ func (s *avroBatchEncoderSuite) TestAvroEncodeOnly(c *check.C) { log.Info("TestAvroEncodeOnly", zap.ByteString("result", txt)) } +func (s *avroBatchEncoderSuite) TestAvroNull(c *check.C) { + defer testleak.AfterTest(c)() + + table := model.TableName{ + Schema: "testdb", + Table: "TestAvroNull", + } + + cols := []*model.Column{ + {Name: "id", Value: int64(1), Flag: model.HandleKeyFlag, Type: mysql.TypeLong}, + {Name: "colNullable", Value: nil, Flag: model.NullableFlag, Type: mysql.TypeLong}, + {Name: "colNotnull", Value: int64(0), Type: mysql.TypeLong}, + {Name: "colNullable1", Value: int64(0), Flag: model.NullableFlag, Type: mysql.TypeLong}, + } + + colInfos := []rowcodec.ColInfo{ + {ID: 1, IsPKHandle: true, VirtualGenCol: false, Ft: types.NewFieldType(mysql.TypeLong)}, + {ID: 2, IsPKHandle: false, VirtualGenCol: false, Ft: types.NewFieldType(mysql.TypeLong)}, + { + ID: 3, IsPKHandle: false, VirtualGenCol: false, + Ft: setFlag(types.NewFieldType(mysql.TypeLong), uint(model.NullableFlag)), + }, + {ID: 4, IsPKHandle: false, VirtualGenCol: false, Ft: types.NewFieldType(mysql.TypeLong)}, + } + + schema, err := ColumnInfoToAvroSchema(table.Table, cols) + c.Assert(err, check.IsNil) + var schemaObj avroSchemaTop + err = json.Unmarshal([]byte(schema), &schemaObj) + c.Assert(err, check.IsNil) + for _, v := range schemaObj.Fields { + if v["name"] == "colNullable" { + c.Assert(v["type"], check.DeepEquals, []interface{}{"null", "int"}) + } + if v["name"] == "colNotnull" { + c.Assert(v["type"], check.Equals, "int") + } + } + + native, err := rowToAvroNativeData(cols, colInfos, time.Local) + c.Assert(err, check.IsNil) + for k, v := range native.(map[string]interface{}) { + if k == "colNullable" { + c.Check(v, check.IsNil) + } + if k == "colNotnull" { + c.Assert(v, check.Equals, int64(0)) + } + if k == "colNullable1" { + c.Assert(v, check.DeepEquals, map[string]interface{}{"int": int64(0)}) + } + } + + avroCodec, err := goavro.NewCodec(schema) + c.Assert(err, check.IsNil) + r, err := avroEncode(&table, s.encoder.valueSchemaManager, 1, cols, colInfos, time.Local) + c.Assert(err, check.IsNil) + + native, _, err = avroCodec.NativeFromBinary(r.data) + c.Check(err, check.IsNil) + c.Check(native, check.NotNil) + for k, v := range native.(map[string]interface{}) { + if k == "colNullable" { + c.Check(v, check.IsNil) + } + if k == "colNotnull" { + c.Assert(v.(int32), check.Equals, int32(0)) + } + if k == "colNullable1" { + c.Assert(v, check.DeepEquals, map[string]interface{}{"int": int32(0)}) + } + } +} + func (s *avroBatchEncoderSuite) TestAvroTimeZone(c *check.C) { defer testleak.AfterTest(c)() @@ -198,7 +273,7 @@ func (s *avroBatchEncoderSuite) TestAvroTimeZone(c *check.C) { res, _, err := avroCodec.NativeFromBinary(r.data) c.Check(err, check.IsNil) c.Check(res, check.NotNil) - actual := (res.(map[string]interface{}))["ts"].(map[string]interface{})["long.timestamp-millis"].(time.Time) + actual := (res.(map[string]interface{}))["ts"].(time.Time) c.Check(actual.Local().Sub(timestamp), check.LessEqual, time.Millisecond) }