diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 4836bb8ad8ea9..b7e5a601a00df 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -829,6 +829,20 @@ func checkColumnDefaultValue(ctx sessionctx.Context, col *table.Column, value in } } } + if value != nil && col.GetType() == mysql.TypeBit { + v, ok := value.(string) + if !ok { + return hasDefaultValue, value, types.ErrInvalidDefault.GenWithStackByArgs(col.Name.O) + } + + uintVal, err := types.BinaryLiteral(v).ToInt(ctx.GetSessionVars().StmtCtx) + if err != nil { + return hasDefaultValue, value, types.ErrInvalidDefault.GenWithStackByArgs(col.Name.O) + } + if col.GetFlen() < 64 && uintVal >= 1<<(uint64(col.GetFlen())) { + return hasDefaultValue, value, types.ErrInvalidDefault.GenWithStackByArgs(col.Name.O) + } + } return hasDefaultValue, value, nil } @@ -4281,13 +4295,14 @@ func setDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu } col.DefaultIsExpr = isSeqExpr } - - if hasDefaultValue, value, err = checkColumnDefaultValue(ctx, col, value); err != nil { - return hasDefaultValue, errors.Trace(err) - } - value, err = convertTimestampDefaultValToUTC(ctx, value, col) - if err != nil { - return hasDefaultValue, errors.Trace(err) + if !col.DefaultIsExpr { + if hasDefaultValue, value, err = checkColumnDefaultValue(ctx, col, value); err != nil { + return hasDefaultValue, errors.Trace(err) + } + value, err = convertTimestampDefaultValToUTC(ctx, value, col) + if err != nil { + return hasDefaultValue, errors.Trace(err) + } } err = setDefaultValueWithBinaryPadding(col, value) if err != nil { diff --git a/ddl/integration_test.go b/ddl/integration_test.go index 34edd8dfe34d6..9f2e74b41192a 100644 --- a/ddl/integration_test.go +++ b/ddl/integration_test.go @@ -144,3 +144,21 @@ func TestDDLOnCachedTable(t *testing.T) { tk.MustExec("alter table t nocache;") tk.MustExec("drop table if exists t;") } + +func TestTooLongDefaultValueForBit(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) + + tk.MustExec("use test;") + + tk.MustGetErrCode("create table t(a bit(2) default b'111');", 1067) + tk.MustGetErrCode("create table t(a bit(65) default b'111');", 1439) + tk.MustExec("create table t(a bit(64) default b'1111111111111111111111111111111111111111111111111111111111111111');") + tk.MustExec("drop table t") + tk.MustExec("create table t(a bit(3) default b'111');") + tk.MustExec("drop table t") + tk.MustExec("create table t(a bit(3) default b'000111');") + tk.MustExec("drop table t;") + tk.MustExec("create table t(a bit(32) default b'1111111111111111111111111111111');") +} diff --git a/executor/write_test.go b/executor/write_test.go index 03fc8df97069b..ff1a97e2c7544 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -1947,7 +1947,7 @@ func TestIssue18681(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("use test") createSQL := `drop table if exists load_data_test; - create table load_data_test (a bit(1),b bit(1),c bit(1),d bit(1));` + create table load_data_test (a bit(1),b bit(1),c bit(1),d bit(1),e bit(32),f bit(1));` tk.MustExec(createSQL) tk.MustExec("load data local infile '/tmp/nonexistence.csv' ignore into table load_data_test") ctx := tk.Session().(sessionctx.Context) @@ -1957,7 +1957,7 @@ func TestIssue18681(t *testing.T) { require.NotNil(t, ld) deleteSQL := "delete from load_data_test" - selectSQL := "select bin(a), bin(b), bin(c), bin(d) from load_data_test;" + selectSQL := "select bin(a), bin(b), bin(c), bin(d), bin(e), bin(f) from load_data_test;" ctx.GetSessionVars().StmtCtx.DupKeyAsWarning = true ctx.GetSessionVars().StmtCtx.BadNullAsWarning = true ld.SetMaxRowsInBatch(20000) @@ -1969,7 +1969,7 @@ func TestIssue18681(t *testing.T) { }() sc.IgnoreTruncate = false tests := []testCase{ - {nil, []byte("true\tfalse\t0\t1\n"), []string{"1|0|0|1"}, nil, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 0"}, + {nil, []byte("true\tfalse\t0\t1\tb'1'\tb'1'\n"), []string{"1|1|1|1|1100010001001110011000100100111|1"}, nil, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 5"}, } checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) require.Equal(t, uint16(0), sc.WarningCount()) diff --git a/types/datum.go b/types/datum.go index 73755f066b8c1..84e094be06cb6 100644 --- a/types/datum.go +++ b/types/datum.go @@ -1553,38 +1553,13 @@ func (d *Datum) ConvertToMysqlYear(sc *stmtctx.StatementContext, target *FieldTy return ret, errors.Trace(err) } -func (d *Datum) convertStringToMysqlBit(sc *stmtctx.StatementContext) (uint64, error) { - bitStr, err := ParseBitStr(BinaryLiteral(d.b).ToString()) - if err != nil { - // It cannot be converted to bit type, so we need to convert it to int type. - return BinaryLiteral(d.b).ToInt(sc) - } - return bitStr.ToInt(sc) -} - func (d *Datum) convertToMysqlBit(sc *stmtctx.StatementContext, target *FieldType) (Datum, error) { var ret Datum var uintValue uint64 var err error switch d.k { - case KindBytes: + case KindString, KindBytes: uintValue, err = BinaryLiteral(d.b).ToInt(sc) - case KindString: - // For single bit value, we take string like "true", "1" as 1, and "false", "0" as 0, - // this behavior is not documented in MySQL, but it behaves so, for more information, see issue #18681 - s := BinaryLiteral(d.b).ToString() - if target.GetFlen() == 1 { - switch strings.ToLower(s) { - case "true", "1": - uintValue = 1 - case "false", "0": - uintValue = 0 - default: - uintValue, err = d.convertStringToMysqlBit(sc) - } - } else { - uintValue, err = d.convertStringToMysqlBit(sc) - } case KindInt64: // if input kind is int64 (signed), when trans to bit, we need to treat it as unsigned d.k = KindUint64 diff --git a/types/datum_test.go b/types/datum_test.go index ca6f199629b4e..47468de7cb7f7 100644 --- a/types/datum_test.go +++ b/types/datum_test.go @@ -525,24 +525,37 @@ func prepareCompareDatums() ([]Datum, []Datum) { func TestStringToMysqlBit(t *testing.T) { tests := []struct { - a Datum - out []byte + a Datum + out []byte + flen int + truncated bool }{ - {NewStringDatum("true"), []byte{1}}, - {NewStringDatum("false"), []byte{0}}, - {NewStringDatum("1"), []byte{1}}, - {NewStringDatum("0"), []byte{0}}, - {NewStringDatum("b'1'"), []byte{1}}, - {NewStringDatum("b'0'"), []byte{0}}, + {NewStringDatum("true"), []byte{1}, 1, true}, + {NewStringDatum("true"), []byte{0x74, 0x72, 0x75, 0x65}, 32, false}, + {NewStringDatum("false"), []byte{0x1}, 1, true}, + {NewStringDatum("false"), []byte{0x66, 0x61, 0x6c, 0x73, 0x65}, 40, false}, + {NewStringDatum("1"), []byte{1}, 1, true}, + {NewStringDatum("1"), []byte{0x31}, 8, false}, + {NewStringDatum("0"), []byte{1}, 1, true}, + {NewStringDatum("0"), []byte{0x30}, 8, false}, + {NewStringDatum("b'1'"), []byte{0x62, 0x27, 0x31, 0x27}, 32, false}, + {NewStringDatum("b'0'"), []byte{0x62, 0x27, 0x30, 0x27}, 32, false}, } sc := new(stmtctx.StatementContext) sc.IgnoreTruncate = true - tp := NewFieldType(mysql.TypeBit) - tp.SetFlen(1) for _, tt := range tests { - bin, err := tt.a.convertToMysqlBit(nil, tp) - require.NoError(t, err) - require.Equal(t, tt.out, bin.b) + t.Run(fmt.Sprintf("%s %d %t", tt.a.GetString(), tt.flen, tt.truncated), func(t *testing.T) { + tp := NewFieldType(mysql.TypeBit) + tp.SetFlen(tt.flen) + + bin, err := tt.a.convertToMysqlBit(sc, tp) + if tt.truncated { + require.Contains(t, err.Error(), "Data Too Long") + } else { + require.NoError(t, err) + } + require.Equal(t, tt.out, bin.b) + }) } }