diff --git a/ddl/ddl_db_test.go b/ddl/ddl_db_test.go index 428791186ecda..ef7929f188f44 100644 --- a/ddl/ddl_db_test.go +++ b/ddl/ddl_db_test.go @@ -157,6 +157,18 @@ func (s *testDBSuite) TestMySQLErrorCode(c *C) { s.testErrorCode(c, sql, tmysql.ErrWrongTableName) } +func (s *testDBSuite) TestAddIndexAfterAddColumn(c *C) { + defer testleak.AfterTest(c)() + s.tk = testkit.NewTestKit(c, s.store) + s.tk.MustExec("use " + s.schemaName) + + s.tk.MustExec("create table test_add_index_after_add_col(a int, b int not null default '0')") + s.tk.MustExec("insert into test_add_index_after_add_col values(1, 2),(2,2)") + s.tk.MustExec("alter table test_add_index_after_add_col add column c int not null default '0'") + sql := "alter table test_add_index_after_add_col add unique index cc(c) " + s.testErrorCode(c, sql, tmysql.ErrDupEntry) +} + func (s *testDBSuite) TestIndex(c *C) { defer testleak.AfterTest(c)() s.tk = testkit.NewTestKit(c, s.store) diff --git a/ddl/index.go b/ddl/index.go index 380a642113851..7069e5a2e8f27 100644 --- a/ddl/index.go +++ b/ddl/index.go @@ -421,7 +421,9 @@ func (d *ddl) fetchRowColVals(txn kv.Transaction, t table.Table, taskOpInfo *ind } cols := t.Cols() + ctx := d.newContext() idxInfo := taskOpInfo.tblIndex.Meta() + defaultVals := make([]types.Datum, len(cols)) for i, idxRecord := range idxRecords { rowMap, err := tablecodec.DecodeRow(rawRecords[i], taskOpInfo.colMap, time.UTC) if err != nil { @@ -431,7 +433,17 @@ func (d *ddl) fetchRowColVals(txn kv.Transaction, t table.Table, taskOpInfo *ind idxVal := make([]types.Datum, 0, len(idxInfo.Columns)) for _, v := range idxInfo.Columns { col := cols[v.Offset] - idxVal = append(idxVal, rowMap[col.ID]) + idxColumnVal := rowMap[col.ID] + if _, ok := rowMap[col.ID]; ok { + idxVal = append(idxVal, idxColumnVal) + continue + } + idxColumnVal, ret.err = tables.GetColDefaultValue(ctx, col, defaultVals) + if ret.err != nil { + ret.err = errors.Trace(ret.err) + return nil, ret + } + idxVal = append(idxVal, idxColumnVal) } idxRecord.vals = idxVal } diff --git a/table/tables/tables.go b/table/tables/tables.go index 715e6800416a6..cea22320b3402 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -459,10 +459,11 @@ func (t *Table) RowWithCols(ctx context.Context, h int64, cols []*table.Column) } colTps[col.ID] = &col.FieldType } - row, err := tablecodec.DecodeRow(value, colTps, ctx.GetSessionVars().GetTimeZone()) + rowMap, err := tablecodec.DecodeRow(value, colTps, ctx.GetSessionVars().GetTimeZone()) if err != nil { return nil, errors.Trace(err) } + defaultVals := make([]types.Datum, len(cols)) for i, col := range cols { if col == nil { continue @@ -470,22 +471,14 @@ func (t *Table) RowWithCols(ctx context.Context, h int64, cols []*table.Column) if col.IsPKHandleColumn(t.meta) { continue } - ri, ok := row[col.ID] + ri, ok := rowMap[col.ID] if ok { v[i] = ri continue } - - if col.OriginDefaultValue != nil && col.State == model.StatePublic { - ri, err = table.GetColOriginDefaultValue(ctx, col.ToInfo()) - if err != nil { - return nil, errors.Trace(err) - } - v[i] = ri - continue - } - if mysql.HasNotNullFlag(col.Flag) { - return nil, errors.New("Miss column") + v[i], err = GetColDefaultValue(ctx, col, defaultVals) + if err != nil { + return nil, errors.Trace(err) } } return v, nil @@ -633,29 +626,21 @@ func (t *Table) IterRecords(ctx context.Context, startKey kv.Key, cols []*table. } data := make([]types.Datum, len(cols)) for _, col := range cols { - if col.IsPKHandleColumn(t.Meta()) { - data[col.Offset] = types.NewIntDatum(handle) + if col.IsPKHandleColumn(t.meta) { + if mysql.HasUnsignedFlag(col.Flag) { + data[col.Offset].SetUint64(uint64(handle)) + } else { + data[col.Offset].SetInt64(handle) + } continue } if _, ok := rowMap[col.ID]; ok { data[col.Offset] = rowMap[col.ID] continue } - if col.OriginDefaultValue == nil && mysql.HasNotNullFlag(col.Flag) { - return errors.New("Miss column") - } - if col.State != model.StatePublic { - continue - } - if defaultVals[col.Offset].IsNull() { - d, err := table.GetColOriginDefaultValue(ctx, col.ToInfo()) - if err != nil { - return errors.Trace(err) - } - data[col.Offset] = d - defaultVals[col.Offset] = d - } else { - data[col.Offset] = defaultVals[col.Offset] + data[col.Offset], err = GetColDefaultValue(ctx, col, defaultVals) + if err != nil { + return errors.Trace(err) } } more, err := fn(handle, data, cols) @@ -673,6 +658,29 @@ func (t *Table) IterRecords(ctx context.Context, startKey kv.Key, cols []*table. return nil } +// GetColDefaultValue gets a column default value. +// The defaultVals is used to avoid calculating the default value multiple times. +func GetColDefaultValue(ctx context.Context, col *table.Column, defaultVals []types.Datum) ( + colVal types.Datum, err error) { + if col.OriginDefaultValue == nil && mysql.HasNotNullFlag(col.Flag) { + return colVal, errors.New("Miss column") + } + if col.State != model.StatePublic { + return colVal, nil + } + if defaultVals[col.Offset].IsNull() { + colVal, err = table.GetColOriginDefaultValue(ctx, col.ToInfo()) + if err != nil { + return colVal, errors.Trace(err) + } + defaultVals[col.Offset] = colVal + } else { + colVal = defaultVals[col.Offset] + } + + return colVal, nil +} + // AllocAutoID implements table.Table AllocAutoID interface. func (t *Table) AllocAutoID() (int64, error) { return t.alloc.Alloc(t.ID)