diff --git a/executor/builder.go b/executor/builder.go index 0079d0264d884..ae179d9f1b867 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -538,15 +538,15 @@ func (b *executorBuilder) buildInsert(v *plannercore.Insert) Executor { baseExec.initCap = chunk.ZeroCapacity ivs := &InsertValues{ - baseExecutor: baseExec, - Table: v.Table, - Columns: v.Columns, - Lists: v.Lists, - SetList: v.SetList, - GenColumns: v.GenCols.Columns, - GenExprs: v.GenCols.Exprs, - needFillDefaultValues: v.NeedFillDefaultValue, - SelectExec: selectExec, + baseExecutor: baseExec, + Table: v.Table, + Columns: v.Columns, + Lists: v.Lists, + SetList: v.SetList, + GenColumns: v.GenCols.Columns, + GenExprs: v.GenCols.Exprs, + hasRefCols: v.NeedFillDefaultValue, + SelectExec: selectExec, } if v.IsReplace { diff --git a/executor/delete.go b/executor/delete.go index 59c9d1eee5fc7..026a1975666d8 100644 --- a/executor/delete.go +++ b/executor/delete.go @@ -39,20 +39,11 @@ type DeleteExec struct { // `delete from t as t1, t as t2`, the same table has two alias, we have to identify a table // by its alias instead of ID. tblMap map[int64][]*ast.TableName - - finished bool } // Next implements the Executor Next interface. func (e *DeleteExec) Next(ctx context.Context, chk *chunk.Chunk) error { chk.Reset() - if e.finished { - return nil - } - defer func() { - e.finished = true - }() - if e.IsMultiTable { return errors.Trace(e.deleteMultiTablesByChunk(ctx)) } diff --git a/executor/executor_test.go b/executor/executor_test.go index 3fe8da019b82e..8ccdec56307b0 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -1477,6 +1477,12 @@ func (s *testSuite) TestGeneratedColumnRead(c *C) { result = tk.MustQuery(`SELECT * FROM test_gc_read WHERE d = 12`) result.Check(testkit.Rows(`3 4 7 12`)) + tk.MustExec(`INSERT INTO test_gc_read set a = 4, b = d + 1`) + result = tk.MustQuery(`SELECT * FROM test_gc_read ORDER BY a`) + result.Check(testkit.Rows(`0 `, `1 2 3 2`, `3 4 7 12`, + `4 `, `8 8 16 64`)) + tk.MustExec(`DELETE FROM test_gc_read where a = 4`) + // Test on-conditions on virtual/stored generated columns. tk.MustExec(`CREATE TABLE test_gc_help(a int primary key, b int, c int, d int)`) tk.MustExec(`INSERT INTO test_gc_help(a, b, c, d) SELECT * FROM test_gc_read`) diff --git a/executor/insert.go b/executor/insert.go index 6828a4d4e4990..932a2c09ce326 100644 --- a/executor/insert.go +++ b/executor/insert.go @@ -31,7 +31,6 @@ type InsertExec struct { *InsertValues OnDuplicate []*expression.Assignment Priority mysql.PriorityEnum - finished bool } func (e *InsertExec) exec(rows [][]types.Datum) error { @@ -68,7 +67,6 @@ func (e *InsertExec) exec(rows [][]types.Datum) error { } } } - e.finished = true return nil } @@ -131,9 +129,6 @@ func (e *InsertExec) batchUpdateDupRows(newRows [][]types.Datum) error { // Next implements Exec Next interface. func (e *InsertExec) Next(ctx context.Context, chk *chunk.Chunk) error { chk.Reset() - if e.finished { - return nil - } cols, err := e.getColumns(e.Table.Cols()) if err != nil { return errors.Trace(err) diff --git a/executor/insert_common.go b/executor/insert_common.go index 628bead17cfbe..2517420943feb 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -34,11 +34,11 @@ type InsertValues struct { baseExecutor batchChecker - rowCount uint64 - maxRowsInBatch uint64 - lastInsertID uint64 - needFillDefaultValues bool - hasExtraHandle bool + rowCount uint64 + maxRowsInBatch uint64 + lastInsertID uint64 + hasRefCols bool + hasExtraHandle bool SelectExec Executor @@ -135,7 +135,7 @@ func (e *InsertValues) lazilyInitColDefaultValBuf() (ok bool) { return false } -func (e *InsertValues) fillValueList() error { +func (e *InsertValues) processSetList() error { if len(e.SetList) > 0 { if len(e.Lists) > 0 { return errors.Errorf("INSERT INTO %s: set type should not use values", e.Table) @@ -149,19 +149,21 @@ func (e *InsertValues) fillValueList() error { return nil } +// insertRows processes `insert|replace into values ()` or `insert|replace into set x=y` func (e *InsertValues) insertRows(cols []*table.Column, exec func(rows [][]types.Datum) error) (err error) { - // process `insert|replace ... set x=y...` - if err = e.fillValueList(); err != nil { + // For `insert|replace into set x=y`, process the set list here. + if err = e.processSetList(); err != nil { return errors.Trace(err) } - rows := make([][]types.Datum, len(e.Lists)) + rows := make([][]types.Datum, 0, len(e.Lists)) for i, list := range e.Lists { e.rowCount++ - rows[i], err = e.getRow(cols, list, i) + row, err := e.evalRow(cols, list, i) if err != nil { return errors.Trace(err) } + rows = append(rows, row) } return errors.Trace(exec(rows)) } @@ -185,9 +187,9 @@ func (e *InsertValues) handleErr(col *table.Column, val *types.Datum, rowIdx int return e.filterErr(err) } -// getRow eval the insert statement. Because the value of column may calculated based on other column, -// it use fillDefaultValues to init the empty row before eval expressions when needFillDefaultValues is true. -func (e *InsertValues) getRow(cols []*table.Column, list []expression.Expression, rowIdx int) ([]types.Datum, error) { +// evalRow evaluates a to-be-inserted row. The value of the column may base on another column, +// so we use setValueForRefColumn to fill the empty row some default values when needFillDefaultValues is true. +func (e *InsertValues) evalRow(cols []*table.Column, list []expression.Expression, rowIdx int) ([]types.Datum, error) { rowLen := len(e.Table.Cols()) if e.hasExtraHandle { rowLen++ @@ -195,8 +197,9 @@ func (e *InsertValues) getRow(cols []*table.Column, list []expression.Expression row := make([]types.Datum, rowLen) hasValue := make([]bool, rowLen) - if e.needFillDefaultValues { - if err := e.fillDefaultValues(row, hasValue); err != nil { + // For statements like `insert into t set a = b + 1`. + if e.hasRefCols { + if err := e.setValueForRefColumn(row, hasValue); err != nil { return nil, errors.Trace(err) } } @@ -215,34 +218,36 @@ func (e *InsertValues) getRow(cols []*table.Column, list []expression.Expression row[offset], hasValue[offset] = val1, true } - return e.fillGenColData(cols, len(list), hasValue, row) + return e.fillRow(row, hasValue) } -// fillDefaultValues fills a row followed by these rules: +// setValueForRefColumn set some default values for the row to eval the row value with other columns, +// it follows these rules: // 1. for nullable and no default value column, use NULL. // 2. for nullable and have default value column, use it's default value. // 3. for not null column, use zero value even in strict mode. // 4. for auto_increment column, use zero value. // 5. for generated column, use NULL. -func (e *InsertValues) fillDefaultValues(row []types.Datum, hasValue []bool) error { +func (e *InsertValues) setValueForRefColumn(row []types.Datum, hasValue []bool) error { for i, c := range e.Table.Cols() { - var err error - if c.IsGenerated() { - continue - } else if mysql.HasAutoIncrementFlag(c.Flag) { - row[i] = table.GetZeroValue(c.ToInfo()) - } else { - row[i], err = e.getColDefaultValue(i, c) - hasValue[c.Offset] = true - if table.ErrNoDefaultValue.Equal(err) { - row[i] = table.GetZeroValue(c.ToInfo()) - hasValue[c.Offset] = false - } else if e.filterErr(err) != nil { - return errors.Trace(err) + d, err := e.getColDefaultValue(i, c) + if err == nil { + row[i] = d + if !mysql.HasAutoIncrementFlag(c.Flag) { + // It is an interesting behavior in MySQL. + // If the value of auto ID is not explicit, MySQL use 0 value for auto ID when it is + // evaluated by another column, but it should be used once only. + // When we fill it as an auto ID column, it should be set as it used to be. + // So just keep `hasValue` false for auto ID, and the others set true. + hasValue[c.Offset] = true } + } else if table.ErrNoDefaultValue.Equal(err) { + row[i] = table.GetZeroValue(c.ToInfo()) + hasValue[c.Offset] = false + } else if e.filterErr(err) != nil { + return errors.Trace(err) } } - return nil } @@ -270,7 +275,7 @@ func (e *InsertValues) insertRowsFromSelect(ctx context.Context, cols []*table.C for innerChunkRow := iter.Begin(); innerChunkRow != iter.End(); innerChunkRow = iter.Next() { innerRow := types.CopyRow(innerChunkRow.GetDatumRow(fields)) e.rowCount++ - row, err := e.fillRowData(cols, innerRow) + row, err := e.getRow(cols, innerRow) if err != nil { return errors.Trace(err) } @@ -297,7 +302,10 @@ func (e *InsertValues) insertRowsFromSelect(ctx context.Context, cols []*table.C return nil } -func (e *InsertValues) fillRowData(cols []*table.Column, vals []types.Datum) ([]types.Datum, error) { +// getRow gets the row which from `insert into select from` or `load data`. +// The input values from these two statements are datums instead of +// expressions which are used in `insert into set x=y`. +func (e *InsertValues) getRow(cols []*table.Column, vals []types.Datum) ([]types.Datum, error) { row := make([]types.Datum, len(e.Table.Cols())) hasValue := make([]bool, len(e.Table.Cols())) for i, v := range vals { @@ -311,32 +319,7 @@ func (e *InsertValues) fillRowData(cols []*table.Column, vals []types.Datum) ([] hasValue[offset] = true } - return e.fillGenColData(cols, len(vals), hasValue, row) -} - -func (e *InsertValues) fillGenColData(cols []*table.Column, valLen int, hasValue []bool, row []types.Datum) ([]types.Datum, error) { - err := e.initDefaultValues(row, hasValue) - if err != nil { - return nil, errors.Trace(err) - } - for i, expr := range e.GenExprs { - var val types.Datum - val, err = expr.Eval(chunk.MutRowFromDatums(row).ToRow()) - if e.filterErr(err) != nil { - return nil, errors.Trace(err) - } - val, err = table.CastValue(e.ctx, val, cols[valLen+i].ToInfo()) - if err != nil { - return nil, errors.Trace(err) - } - offset := cols[valLen+i].Offset - row[offset] = val - } - - if err = table.CheckNotNull(e.Table.Cols(), row); err != nil { - return nil, errors.Trace(err) - } - return row, nil + return e.fillRow(row, hasValue) } func (e *InsertValues) filterErr(err error) error { @@ -351,6 +334,7 @@ func (e *InsertValues) filterErr(err error) error { return nil } +// getColDefaultValue gets the column default value. func (e *InsertValues) getColDefaultValue(idx int, col *table.Column) (d types.Datum, err error) { if e.colDefaultVals != nil && e.colDefaultVals[idx].valid { return e.colDefaultVals[idx].val, nil @@ -368,81 +352,100 @@ func (e *InsertValues) getColDefaultValue(idx int, col *table.Column) (d types.D return defaultVal, nil } -// initDefaultValues fills generated columns, auto_increment column and empty column. +// fillColValue fills the column value if it is not set in the insert statement. +func (e *InsertValues) fillColValue(datum types.Datum, idx int, column *table.Column, hasValue bool) (types.Datum, + error) { + if mysql.HasAutoIncrementFlag(column.Flag) { + d, err := e.adjustAutoIncrementDatum(datum, hasValue, column) + if err != nil { + return types.Datum{}, errors.Trace(err) + } + return d, nil + } + if !hasValue { + d, err := e.getColDefaultValue(idx, column) + if e.filterErr(err) != nil { + return types.Datum{}, errors.Trace(err) + } + return d, nil + } + return datum, nil +} + +// fillRow fills generated columns, auto_increment column and empty column. // For NOT NULL column, it will return error or use zero value based on sql_mode. -func (e *InsertValues) initDefaultValues(row []types.Datum, hasValue []bool) error { +func (e *InsertValues) fillRow(row []types.Datum, hasValue []bool) ([]types.Datum, error) { + gIdx := 0 for i, c := range e.Table.Cols() { - if mysql.HasAutoIncrementFlag(c.Flag) || c.IsGenerated() { - // Just leave generated column as null. It will be calculated later - // but before we check whether the column can be null or not. - if !hasValue[i] { - row[i].SetNull() - } - // Adjust the value if this column has auto increment flag. - if mysql.HasAutoIncrementFlag(c.Flag) { - if err := e.adjustAutoIncrementDatum(row, i, c); err != nil { - return errors.Trace(err) - } + var err error + // Get the default value for all no value columns, the auto increment column is different from the others. + row[i], err = e.fillColValue(row[i], i, c, hasValue[i]) + if err != nil { + return nil, errors.Trace(err) + } + + // Evaluate the generated columns. + if c.IsGenerated() { + var val types.Datum + val, err = e.GenExprs[gIdx].Eval(chunk.MutRowFromDatums(row).ToRow()) + gIdx++ + if e.filterErr(err) != nil { + return nil, errors.Trace(err) } - } else { - if !hasValue[i] || (mysql.HasNotNullFlag(c.Flag) && row[i]. - IsNull() && e.ctx.GetSessionVars().StmtCtx.BadNullAsWarning) { - var err error - row[i], err = e.getColDefaultValue(i, c) - if e.filterErr(err) != nil { - return errors.Trace(err) - } + row[i], err = table.CastValue(e.ctx, val, c.ToInfo()) + if err != nil { + return nil, errors.Trace(err) } } + + // Handle the bad null error. + if row[i], err = c.HandleBadNull(row[i], e.ctx.GetSessionVars().StmtCtx); err != nil { + return nil, errors.Trace(err) + } } - return nil + return row, nil } -func (e *InsertValues) adjustAutoIncrementDatum(row []types.Datum, i int, c *table.Column) error { +func (e *InsertValues) adjustAutoIncrementDatum(d types.Datum, hasValue bool, c *table.Column) (types.Datum, error) { retryInfo := e.ctx.GetSessionVars().RetryInfo if retryInfo.Retrying { id, err := retryInfo.GetCurrAutoIncrementID() if err != nil { - return errors.Trace(err) + return types.Datum{}, errors.Trace(err) } - if mysql.HasUnsignedFlag(c.Flag) { - row[i].SetUint64(uint64(id)) - } else { - row[i].SetInt64(id) - } - return nil + d.SetAutoID(id, c.Flag) + return d, nil } var err error var recordID int64 - if !row[i].IsNull() { - recordID, err = row[i].ToInt64(e.ctx.GetSessionVars().StmtCtx) + if !hasValue { + d.SetNull() + } + if !d.IsNull() { + recordID, err = d.ToInt64(e.ctx.GetSessionVars().StmtCtx) if e.filterErr(err) != nil { - return errors.Trace(err) + return types.Datum{}, errors.Trace(err) } } // Use the value if it's not null and not 0. if recordID != 0 { err = e.Table.RebaseAutoID(e.ctx, recordID, true) if err != nil { - return errors.Trace(err) + return types.Datum{}, errors.Trace(err) } e.ctx.GetSessionVars().InsertID = uint64(recordID) - if mysql.HasUnsignedFlag(c.Flag) { - row[i].SetUint64(uint64(recordID)) - } else { - row[i].SetInt64(recordID) - } retryInfo.AddAutoIncrementID(recordID) - return nil + d.SetAutoID(recordID, c.Flag) + return d, nil } // Change NULL to auto id. // Change value 0 to auto id, if NoAutoValueOnZero SQL mode is not set. - if row[i].IsNull() || e.ctx.GetSessionVars().SQLMode&mysql.ModeNoAutoValueOnZero == 0 { + if d.IsNull() || e.ctx.GetSessionVars().SQLMode&mysql.ModeNoAutoValueOnZero == 0 { recordID, err = e.Table.AllocAutoID(e.ctx) if e.filterErr(err) != nil { - return errors.Trace(err) + return types.Datum{}, errors.Trace(err) } // It's compatible with mysql. So it sets last insert id to the first row. if e.rowCount == 1 { @@ -450,20 +453,15 @@ func (e *InsertValues) adjustAutoIncrementDatum(row []types.Datum, i int, c *tab } } - if mysql.HasUnsignedFlag(c.Flag) { - row[i].SetUint64(uint64(recordID)) - } else { - row[i].SetInt64(recordID) - } + d.SetAutoID(recordID, c.Flag) retryInfo.AddAutoIncrementID(recordID) - // the value of row[i] is adjusted by autoid, so we need to cast it again. - casted, err := table.CastValue(e.ctx, row[i], c.ToInfo()) + // the value of d is adjusted by auto ID, so we need to cast it again. + casted, err := table.CastValue(e.ctx, d, c.ToInfo()) if err != nil { - return errors.Trace(err) + return types.Datum{}, errors.Trace(err) } - row[i] = casted - return nil + return casted, nil } func (e *InsertValues) handleWarning(err error, logInfo string) { diff --git a/executor/load_data.go b/executor/load_data.go index 652ad68d1aa09..d6ee3724a3a74 100644 --- a/executor/load_data.go +++ b/executor/load_data.go @@ -274,7 +274,7 @@ func (e *LoadDataInfo) colsToRow(cols []field) []types.Datum { e.row[i].SetString(string(cols[i].str)) } } - row, err := e.fillRowData(e.columns, e.row) + row, err := e.getRow(e.columns, e.row) if err != nil { e.handleWarning(err, fmt.Sprintf("Load Data: insert data:%v failed:%v", e.row, errors.ErrorStack(err))) diff --git a/executor/replace.go b/executor/replace.go index 6f416c9ea0f9b..dbe79cfde4817 100644 --- a/executor/replace.go +++ b/executor/replace.go @@ -27,7 +27,6 @@ import ( type ReplaceExec struct { *InsertValues Priority int - finished bool } // Close implements the Executor Close interface. @@ -173,16 +172,12 @@ func (e *ReplaceExec) exec(newRows [][]types.Datum) error { return errors.Trace(err) } } - e.finished = true return nil } // Next implements the Executor Next interface. func (e *ReplaceExec) Next(ctx context.Context, chk *chunk.Chunk) error { chk.Reset() - if e.finished { - return nil - } cols, err := e.getColumns(e.Table.Cols()) if err != nil { return errors.Trace(err) diff --git a/executor/write.go b/executor/write.go index dc4c993182720..de34f67bce23b 100644 --- a/executor/write.go +++ b/executor/write.go @@ -68,17 +68,11 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu } } - // 2. Check null. + // 2. Handle the bad null error. for i, col := range t.Cols() { - if err := col.CheckNotNull(newData[i]); err != nil { - if sc.BadNullAsWarning { - newData[i], err = table.GetColDefaultValue(ctx, col.ToInfo()) - if err != nil { - return false, false, 0, errors.Trace(err) - } - } else { - return false, false, 0, errors.Trace(err) - } + var err error + if newData[i], err = col.HandleBadNull(newData[i], sc); err != nil { + return false, false, 0, errors.Trace(err) } } diff --git a/executor/write_test.go b/executor/write_test.go index 033798c754a82..98a91cf58d455 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -635,11 +635,16 @@ commit;` tk.MustQuery(`SELECT * FROM t1;`).Check(testkit.Rows("1")) testSQL = `DROP TABLE IF EXISTS t1; - CREATE TABLE t1 (f1 INT PRIMARY KEY, f2 INT UNIQUE); + CREATE TABLE t1 (f1 INT PRIMARY KEY, f2 INT NOT NULL UNIQUE); INSERT t1 VALUES (1, 1);` tk.MustExec(testSQL) tk.MustExec(`INSERT t1 VALUES (1, 1), (1, 1) ON DUPLICATE KEY UPDATE f1 = 2, f2 = 2;`) tk.MustQuery(`SELECT * FROM t1 order by f1;`).Check(testkit.Rows("1 1", "2 2")) + _, err := tk.Exec(`INSERT t1 VALUES (1, 1) ON DUPLICATE KEY UPDATE f2 = null;`) + c.Assert(err, NotNil) + tk.MustExec(`INSERT IGNORE t1 VALUES (1, 1) ON DUPLICATE KEY UPDATE f2 = null;`) + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1048 Column 'f2' cannot be null")) + tk.MustQuery(`SELECT * FROM t1 order by f1;`).Check(testkit.Rows("1 0", "2 2")) } func (s *testSuite) TestInsertIgnoreOnDup(c *C) { @@ -1056,6 +1061,15 @@ func (s *testSuite) TestUpdate(c *C) { tk.MustExec("update (select * from t) t set c1 = 1111111") + // test update ignore for bad null error + tk.MustExec("drop table if exists t;") + tk.MustExec(`create table t (i int not null default 10)`) + tk.MustExec("insert into t values (1)") + tk.MustExec("update ignore t set i = null;") + r = tk.MustQuery("SHOW WARNINGS;") + r.Check(testkit.Rows("Warning 1048 Column 'i' cannot be null")) + tk.MustQuery("select * from t").Check(testkit.Rows("0")) + // issue 7237, update subquery table should be forbidden tk.MustExec("drop table t") tk.MustExec("create table t (k int, v int)") diff --git a/table/column.go b/table/column.go index c2b408943f5c5..8e6bb6925a13d 100644 --- a/table/column.go +++ b/table/column.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" "github.com/pingcap/tidb/util/charset" @@ -291,6 +292,19 @@ func (c *Column) CheckNotNull(data types.Datum) error { return nil } +// HandleBadNull handles the bad null error. +// If BadNullAsWarning is true, it will append the error as a warning, else return the error. +func (c *Column) HandleBadNull(d types.Datum, sc *stmtctx.StatementContext) (types.Datum, error) { + if err := c.CheckNotNull(d); err != nil { + if sc.BadNullAsWarning { + sc.AppendWarning(err) + return GetZeroValue(c.ToInfo()), nil + } + return types.Datum{}, errors.Trace(err) + } + return d, nil +} + // IsPKHandleColumn checks if the column is primary key handle column. func (c *Column) IsPKHandleColumn(tbInfo *model.TableInfo) bool { return mysql.HasPriKeyFlag(c.Flag) && tbInfo.PKIsHandle @@ -348,6 +362,9 @@ func getColDefaultValueFromNil(ctx sessionctx.Context, col *model.ColumnInfo) (t } if mysql.HasAutoIncrementFlag(col.Flag) { // Auto increment column doesn't has default value and we should not return error. + return GetZeroValue(col), nil + } + if col.IsGenerated() { return types.Datum{}, nil } sc := ctx.GetSessionVars().StmtCtx diff --git a/table/column_test.go b/table/column_test.go index bf17187a4d9ab..4bf3bc104c0c3 100644 --- a/table/column_test.go +++ b/table/column_test.go @@ -338,7 +338,7 @@ func (t *testTableSuite) TestGetDefaultValue(c *C) { }, }, true, - types.Datum{}, + types.NewIntDatum(0), nil, }, } diff --git a/types/datum.go b/types/datum.go index 34add19204498..959231e633b0f 100644 --- a/types/datum.go +++ b/types/datum.go @@ -323,6 +323,15 @@ func (d *Datum) GetRaw() []byte { return d.b } +// SetAutoID set the auto increment ID according to its int flag. +func (d *Datum) SetAutoID(id int64, flag uint) { + if mysql.HasUnsignedFlag(flag) { + d.SetUint64(uint64(id)) + } else { + d.SetInt64(id) + } +} + // GetValue gets the value of the datum of any kind. func (d *Datum) GetValue() interface{} { switch d.k {