Skip to content

Commit

Permalink
*: Fix add index after add column with default value (#3510) (#3513)
Browse files Browse the repository at this point in the history
  • Loading branch information
zimulala authored Jun 20, 2017
1 parent 14d4abd commit 2e0e945
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 31 deletions.
12 changes: 12 additions & 0 deletions ddl/ddl_db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,18 @@ func (s *testDBSuite) TestMySQLErrorCode(c *C) {
s.testErrorCode(c, sql, tmysql.ErrWrongTableName)
}

func (s *testDBSuite) TestAddIndexAfterAddColumn(c *C) {
defer testleak.AfterTest(c)()
s.tk = testkit.NewTestKit(c, s.store)
s.tk.MustExec("use " + s.schemaName)

s.tk.MustExec("create table test_add_index_after_add_col(a int, b int not null default '0')")
s.tk.MustExec("insert into test_add_index_after_add_col values(1, 2),(2,2)")
s.tk.MustExec("alter table test_add_index_after_add_col add column c int not null default '0'")
sql := "alter table test_add_index_after_add_col add unique index cc(c) "
s.testErrorCode(c, sql, tmysql.ErrDupEntry)
}

func (s *testDBSuite) TestIndex(c *C) {
defer testleak.AfterTest(c)()
s.tk = testkit.NewTestKit(c, s.store)
Expand Down
14 changes: 13 additions & 1 deletion ddl/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,9 @@ func (d *ddl) fetchRowColVals(txn kv.Transaction, t table.Table, taskOpInfo *ind
}

cols := t.Cols()
ctx := d.newContext()
idxInfo := taskOpInfo.tblIndex.Meta()
defaultVals := make([]types.Datum, len(cols))
for i, idxRecord := range idxRecords {
rowMap, err := tablecodec.DecodeRow(rawRecords[i], taskOpInfo.colMap)
if err != nil {
Expand All @@ -423,7 +425,17 @@ func (d *ddl) fetchRowColVals(txn kv.Transaction, t table.Table, taskOpInfo *ind
idxVal := make([]types.Datum, 0, len(idxInfo.Columns))
for _, v := range idxInfo.Columns {
col := cols[v.Offset]
idxVal = append(idxVal, rowMap[col.ID])
idxColumnVal := rowMap[col.ID]
if _, ok := rowMap[col.ID]; ok {
idxVal = append(idxVal, idxColumnVal)
continue
}
idxColumnVal, ret.err = tables.GetColDefaultValue(ctx, col, defaultVals)
if ret.err != nil {
ret.err = errors.Trace(ret.err)
return nil, ret
}
idxVal = append(idxVal, idxColumnVal)
}
idxRecord.vals = idxVal
}
Expand Down
68 changes: 38 additions & 30 deletions table/tables/tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,33 +450,26 @@ func (t *Table) RowWithCols(ctx context.Context, h int64, cols []*table.Column)
}
colTps[col.ID] = &col.FieldType
}
row, err := tablecodec.DecodeRow(value, colTps)
rowMap, err := tablecodec.DecodeRow(value, colTps)
if err != nil {
return nil, errors.Trace(err)
}
defaultVals := make([]types.Datum, len(cols))
for i, col := range cols {
if col == nil {
continue
}
if col.IsPKHandleColumn(t.meta) {
continue
}
ri, ok := row[col.ID]
ri, ok := rowMap[col.ID]
if ok {
v[i] = ri
continue
}

if col.OriginDefaultValue != nil && col.State == model.StatePublic {
ri, err = table.GetColOriginDefaultValue(ctx, col.ToInfo())
if err != nil {
return nil, errors.Trace(err)
}
v[i] = ri
continue
}
if mysql.HasNotNullFlag(col.Flag) {
return nil, errors.New("Miss column")
v[i], err = GetColDefaultValue(ctx, col, defaultVals)
if err != nil {
return nil, errors.Trace(err)
}
}
return v, nil
Expand Down Expand Up @@ -624,29 +617,21 @@ func (t *Table) IterRecords(ctx context.Context, startKey kv.Key, cols []*table.
}
data := make([]types.Datum, len(cols))
for _, col := range cols {
if col.IsPKHandleColumn(t.Meta()) {
data[col.Offset] = types.NewIntDatum(handle)
if col.IsPKHandleColumn(t.meta) {
if mysql.HasUnsignedFlag(col.Flag) {
data[col.Offset].SetUint64(uint64(handle))
} else {
data[col.Offset].SetInt64(handle)
}
continue
}
if _, ok := rowMap[col.ID]; ok {
data[col.Offset] = rowMap[col.ID]
continue
}
if col.OriginDefaultValue == nil && mysql.HasNotNullFlag(col.Flag) {
return errors.New("Miss column")
}
if col.State != model.StatePublic {
continue
}
if defaultVals[col.Offset].IsNull() {
d, err := table.GetColOriginDefaultValue(ctx, col.ToInfo())
if err != nil {
return errors.Trace(err)
}
data[col.Offset] = d
defaultVals[col.Offset] = d
} else {
data[col.Offset] = defaultVals[col.Offset]
data[col.Offset], err = GetColDefaultValue(ctx, col, defaultVals)
if err != nil {
return errors.Trace(err)
}
}
more, err := fn(handle, data, cols)
Expand All @@ -664,6 +649,29 @@ func (t *Table) IterRecords(ctx context.Context, startKey kv.Key, cols []*table.
return nil
}

// GetColDefaultValue gets a column default value.
// The defaultVals is used to avoid calculating the default value multiple times.
func GetColDefaultValue(ctx context.Context, col *table.Column, defaultVals []types.Datum) (
colVal types.Datum, err error) {
if col.OriginDefaultValue == nil && mysql.HasNotNullFlag(col.Flag) {
return colVal, errors.New("Miss column")
}
if col.State != model.StatePublic {
return colVal, nil
}
if defaultVals[col.Offset].IsNull() {
colVal, err = table.GetColOriginDefaultValue(ctx, col.ToInfo())
if err != nil {
return colVal, errors.Trace(err)
}
defaultVals[col.Offset] = colVal
} else {
colVal = defaultVals[col.Offset]
}

return colVal, nil
}

// AllocAutoID implements table.Table AllocAutoID interface.
func (t *Table) AllocAutoID() (int64, error) {
return t.alloc.Alloc(t.ID)
Expand Down

0 comments on commit 2e0e945

Please sign in to comment.