diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 5a9dfb1fb1a69..478f1ccad23eb 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -1713,13 +1713,16 @@ func (d *ddl) CreateIndex(ctx sessionctx.Context, ti ast.Ident, unique bool, ind return errors.Trace(err) } -func buildFKInfo(fkName model.CIStr, keys []*ast.IndexColName, refer *ast.ReferenceDef) (*model.FKInfo, error) { +func buildFKInfo(fkName model.CIStr, keys []*ast.IndexColName, refer *ast.ReferenceDef, cols []*table.Column) (*model.FKInfo, error) { var fkInfo model.FKInfo fkInfo.Name = fkName fkInfo.RefTable = refer.Table.Name fkInfo.Cols = make([]model.CIStr, len(keys)) for i, key := range keys { + if table.FindCol(cols, key.Column.Name.O) == nil { + return nil, errKeyColumnDoesNotExits.Gen("key column %s doesn't exist in table", key.Column.Name) + } fkInfo.Cols[i] = key.Column.Name } @@ -1747,7 +1750,7 @@ func (d *ddl) CreateForeignKey(ctx sessionctx.Context, ti ast.Ident, fkName mode return errors.Trace(infoschema.ErrTableNotExists.GenByArgs(ti.Schema, ti.Name)) } - fkInfo, err := buildFKInfo(fkName, keys, refer) + fkInfo, err := buildFKInfo(fkName, keys, refer, t.Cols()) if err != nil { return errors.Trace(err) } diff --git a/ddl/ddl_db_test.go b/ddl/ddl_db_test.go index d3cabcad097df..a7ce45e60fd7e 100644 --- a/ddl/ddl_db_test.go +++ b/ddl/ddl_db_test.go @@ -1447,9 +1447,14 @@ func (s *testDBSuite) TestTableForeignKey(c *C) { s.tk = testkit.NewTestKit(c, s.store) s.tk.MustExec("use test") s.tk.MustExec("create table t1 (a int, b int);") + // test create table with foreign key. failSQL := "create table t2 (c int, foreign key (a) references t1(a));" s.testErrorCode(c, failSQL, tmysql.ErrKeyColumnDoesNotExits) - s.tk.MustExec("drop table if exists t1,t2;") + // test add foreign key. + s.tk.MustExec("create table t3 (a int, b int);") + failSQL = "alter table t1 add foreign key (c) REFERENCES t3(a);" + s.testErrorCode(c, failSQL, tmysql.ErrKeyColumnDoesNotExits) + s.tk.MustExec("drop table if exists t1,t2,t3;") } func (s *testDBSuite) TestCreateTableWithPartition(c *C) {