diff --git a/pkg/ddl/ddl_api.go b/pkg/ddl/ddl_api.go index 3df1ab877740f..05fe65ee7ccb6 100644 --- a/pkg/ddl/ddl_api.go +++ b/pkg/ddl/ddl_api.go @@ -64,6 +64,7 @@ import ( "github.com/pingcap/tidb/pkg/util/dbterror" "github.com/pingcap/tidb/pkg/util/domainutil" "github.com/pingcap/tidb/pkg/util/hack" + "github.com/pingcap/tidb/pkg/util/intest" "github.com/pingcap/tidb/pkg/util/logutil" "github.com/pingcap/tidb/pkg/util/mathutil" "github.com/pingcap/tidb/pkg/util/memory" @@ -1031,6 +1032,21 @@ func checkColumnDefaultValue(ctx sessionctx.Context, col *table.Column, value in } } } + if value != nil && col.GetType() == mysql.TypeBit { + v, ok := value.(string) + if !ok { + return hasDefaultValue, value, types.ErrInvalidDefault.GenWithStackByArgs(col.Name.O) + } + + uintVal, err := types.BinaryLiteral(v).ToInt(ctx.GetSessionVars().StmtCtx) + if err != nil { + return hasDefaultValue, value, types.ErrInvalidDefault.GenWithStackByArgs(col.Name.O) + } + intest.Assert(col.GetFlen() > 0 && col.GetFlen() <= 64) + if col.GetFlen() < 64 && uintVal >= 1<<(uint64(col.GetFlen())) { + return hasDefaultValue, value, types.ErrInvalidDefault.GenWithStackByArgs(col.Name.O) + } + } return hasDefaultValue, value, nil } @@ -5283,13 +5299,14 @@ func SetDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu } col.DefaultIsExpr = isSeqExpr } - - if hasDefaultValue, value, err = checkColumnDefaultValue(ctx, col, value); err != nil { - return hasDefaultValue, errors.Trace(err) - } - value, err = convertTimestampDefaultValToUTC(ctx, value, col) - if err != nil { - return hasDefaultValue, errors.Trace(err) + if !col.DefaultIsExpr { + if hasDefaultValue, value, err = checkColumnDefaultValue(ctx, col, value); err != nil { + return hasDefaultValue, errors.Trace(err) + } + value, err = convertTimestampDefaultValToUTC(ctx, value, col) + if err != nil { + return hasDefaultValue, errors.Trace(err) + } } err = setDefaultValueWithBinaryPadding(col, value) if err != nil { diff --git a/pkg/executor/test/writetest/write_test.go b/pkg/executor/test/writetest/write_test.go index 07bd89cb9206f..09dadd30bfa67 100644 --- a/pkg/executor/test/writetest/write_test.go +++ b/pkg/executor/test/writetest/write_test.go @@ -1301,7 +1301,7 @@ func TestIssue18681(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("use test") createSQL := `drop table if exists load_data_test; - create table load_data_test (a bit(1),b bit(1),c bit(1),d bit(1));` + create table load_data_test (a bit(1),b bit(1),c bit(1),d bit(1),e bit(32),f bit(1));` tk.MustExec(createSQL) tk.MustExec("load data local infile '/tmp/nonexistence.csv' ignore into table load_data_test") ctx := tk.Session().(sessionctx.Context) @@ -1311,7 +1311,7 @@ func TestIssue18681(t *testing.T) { require.NotNil(t, ld) deleteSQL := "delete from load_data_test" - selectSQL := "select bin(a), bin(b), bin(c), bin(d) from load_data_test;" + selectSQL := "select bin(a), bin(b), bin(c), bin(d), bin(e), bin(f) from load_data_test;" ctx.GetSessionVars().StmtCtx.DupKeyAsWarning = true ctx.GetSessionVars().StmtCtx.BadNullAsWarning = true @@ -1322,7 +1322,7 @@ func TestIssue18681(t *testing.T) { }() sc.IgnoreTruncate.Store(false) tests := []testCase{ - {[]byte("true\tfalse\t0\t1\n"), []string{"1|0|0|1"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 0"}, + {[]byte("true\tfalse\t0\t1\tb'1'\tb'1'\n"), []string{"1|1|1|1|1100010001001110011000100100111|1"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 5"}, } checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) require.Equal(t, uint16(0), sc.WarningCount()) diff --git a/pkg/types/datum.go b/pkg/types/datum.go index dfaa7fae9adeb..015f7090ea7c9 100644 --- a/pkg/types/datum.go +++ b/pkg/types/datum.go @@ -1576,38 +1576,13 @@ func (d *Datum) ConvertToMysqlYear(sc *stmtctx.StatementContext, target *FieldTy return ret, errors.Trace(err) } -func (d *Datum) convertStringToMysqlBit(sc *stmtctx.StatementContext) (uint64, error) { - bitStr, err := ParseBitStr(BinaryLiteral(d.b).ToString()) - if err != nil { - // It cannot be converted to bit type, so we need to convert it to int type. - return BinaryLiteral(d.b).ToInt(sc) - } - return bitStr.ToInt(sc) -} - func (d *Datum) convertToMysqlBit(sc *stmtctx.StatementContext, target *FieldType) (Datum, error) { var ret Datum var uintValue uint64 var err error switch d.k { - case KindBytes: + case KindString, KindBytes: uintValue, err = BinaryLiteral(d.b).ToInt(sc) - case KindString: - // For single bit value, we take string like "true", "1" as 1, and "false", "0" as 0, - // this behavior is not documented in MySQL, but it behaves so, for more information, see issue #18681 - s := BinaryLiteral(d.b).ToString() - if target.GetFlen() == 1 { - switch strings.ToLower(s) { - case "true", "1": - uintValue = 1 - case "false", "0": - uintValue = 0 - default: - uintValue, err = d.convertStringToMysqlBit(sc) - } - } else { - uintValue, err = d.convertStringToMysqlBit(sc) - } case KindInt64: // if input kind is int64 (signed), when trans to bit, we need to treat it as unsigned d.k = KindUint64 diff --git a/pkg/types/datum_test.go b/pkg/types/datum_test.go index 4b6a4e576207e..88b231194b15d 100644 --- a/pkg/types/datum_test.go +++ b/pkg/types/datum_test.go @@ -527,24 +527,37 @@ func prepareCompareDatums() ([]Datum, []Datum) { func TestStringToMysqlBit(t *testing.T) { tests := []struct { - a Datum - out []byte + a Datum + out []byte + flen int + truncated bool }{ - {NewStringDatum("true"), []byte{1}}, - {NewStringDatum("false"), []byte{0}}, - {NewStringDatum("1"), []byte{1}}, - {NewStringDatum("0"), []byte{0}}, - {NewStringDatum("b'1'"), []byte{1}}, - {NewStringDatum("b'0'"), []byte{0}}, + {NewStringDatum("true"), []byte{1}, 1, true}, + {NewStringDatum("true"), []byte{0x74, 0x72, 0x75, 0x65}, 32, false}, + {NewStringDatum("false"), []byte{0x1}, 1, true}, + {NewStringDatum("false"), []byte{0x66, 0x61, 0x6c, 0x73, 0x65}, 40, false}, + {NewStringDatum("1"), []byte{1}, 1, true}, + {NewStringDatum("1"), []byte{0x31}, 8, false}, + {NewStringDatum("0"), []byte{1}, 1, true}, + {NewStringDatum("0"), []byte{0x30}, 8, false}, + {NewStringDatum("b'1'"), []byte{0x62, 0x27, 0x31, 0x27}, 32, false}, + {NewStringDatum("b'0'"), []byte{0x62, 0x27, 0x30, 0x27}, 32, false}, } sc := stmtctx.NewStmtCtx() sc.IgnoreTruncate.Store(true) - tp := NewFieldType(mysql.TypeBit) - tp.SetFlen(1) for _, tt := range tests { - bin, err := tt.a.convertToMysqlBit(nil, tp) - require.NoError(t, err) - require.Equal(t, tt.out, bin.b) + t.Run(fmt.Sprintf("%s %d %t", tt.a.GetString(), tt.flen, tt.truncated), func(t *testing.T) { + tp := NewFieldType(mysql.TypeBit) + tp.SetFlen(tt.flen) + + bin, err := tt.a.convertToMysqlBit(sc, tp) + if tt.truncated { + require.Contains(t, err.Error(), "Data Too Long") + } else { + require.NoError(t, err) + } + require.Equal(t, tt.out, bin.b) + }) } } diff --git a/tests/integrationtest/r/ddl/column.result b/tests/integrationtest/r/ddl/column.result index 7d285b5ec4881..e4e173bead33d 100644 --- a/tests/integrationtest/r/ddl/column.result +++ b/tests/integrationtest/r/ddl/column.result @@ -65,3 +65,15 @@ t CREATE TABLE `t` ( `a` decimal(10,0) DEFAULT NULL, `b` decimal(10,0) DEFAULT NULL ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin +drop table if exists t; +create table t(a bit(2) default b'111'); +Error 1067 (42000): Invalid default value for 'a' +create table t(a bit(65) default b'111'); +Error 1439 (42000): Display width out of range for column 'a' (max = 64) +create table t(a bit(64) default b'1111111111111111111111111111111111111111111111111111111111111111'); +drop table t; +create table t(a bit(3) default b'111'); +drop table t; +create table t(a bit(3) default b'000111'); +drop table t; +create table t(a bit(32) default b'1111111111111111111111111111111'); diff --git a/tests/integrationtest/r/table/tables.result b/tests/integrationtest/r/table/tables.result index e09bc5127f4c5..b62f9a0810b62 100644 --- a/tests/integrationtest/r/table/tables.result +++ b/tests/integrationtest/r/table/tables.result @@ -6,3 +6,9 @@ select count(distinct(_tidb_rowid>>48)) from shard_t; count(distinct(_tidb_rowid>>48)) 4 set @@tidb_shard_allocate_step=default; +drop table if exists t; +create table t(a bit(32) default b'1100010001001110011000100100111'); +insert into t values (); +select hex(a) from t; +hex(a) +62273127 diff --git a/tests/integrationtest/t/ddl/column.test b/tests/integrationtest/t/ddl/column.test index 0376192a6ada0..36c093c4dc146 100644 --- a/tests/integrationtest/t/ddl/column.test +++ b/tests/integrationtest/t/ddl/column.test @@ -22,3 +22,17 @@ show create table t2; drop table if exists t; create table t(a decimal(0,0), b decimal(0)); show create table t; + +# TestTooLongDefaultValueForBit +drop table if exists t; +-- error 1067 +create table t(a bit(2) default b'111'); +-- error 1439 +create table t(a bit(65) default b'111'); +create table t(a bit(64) default b'1111111111111111111111111111111111111111111111111111111111111111'); +drop table t; +create table t(a bit(3) default b'111'); +drop table t; +create table t(a bit(3) default b'000111'); +drop table t; +create table t(a bit(32) default b'1111111111111111111111111111111'); \ No newline at end of file diff --git a/tests/integrationtest/t/table/tables.test b/tests/integrationtest/t/table/tables.test index b2896d7a39a2e..771825b0c0f95 100644 --- a/tests/integrationtest/t/table/tables.test +++ b/tests/integrationtest/t/table/tables.test @@ -5,3 +5,9 @@ set @@tidb_shard_allocate_step=3; insert into shard_t values (1), (2), (3), (4), (5), (6), (7), (8), (9), (10), (11); select count(distinct(_tidb_rowid>>48)) from shard_t; set @@tidb_shard_allocate_step=default; + +# TestInsertBitDefaultValue +drop table if exists t; +create table t(a bit(32) default b'1100010001001110011000100100111'); +insert into t values (); +select hex(a) from t; \ No newline at end of file