diff --git a/executor/update_test.go b/executor/update_test.go index 4bcba846934b3..405ac69a6faf1 100644 --- a/executor/update_test.go +++ b/executor/update_test.go @@ -216,3 +216,14 @@ func (s *testUpdateSuite) TestUpdateWithAutoidSchema(c *C) { tk.MustQuery(tt.query).Check(tt.result) } } + +func (s *testUpdateSuite) TestUpdateWithSubquery(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table t1(id varchar(30) not null, status varchar(1) not null default 'N', id2 varchar(30))") + tk.MustExec("create table t2(id varchar(30) not null, field varchar(4) not null)") + tk.MustExec("insert into t1 values('abc', 'F', 'abc')") + tk.MustExec("insert into t2 values('abc', 'MAIN')") + tk.MustExec("update t1 set status = 'N' where status = 'F' and (id in (select id from t2 where field = 'MAIN') or id2 in (select id from t2 where field = 'main'))") + tk.MustQuery("select * from t1").Check(testkit.Rows("abc N abc")) +} diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 6b78d16a8dcaf..e4d69f56c0244 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -2263,12 +2263,14 @@ func (b *planBuilder) buildUpdate(update *ast.UpdateStmt) (Plan, error) { b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, dbName, t.Name.L, "") } + oldSchema := p.Schema().Clone() if sel.Where != nil { p, err = b.buildSelection(p, sel.Where, nil) if err != nil { return nil, errors.Trace(err) } } + if sel.OrderBy != nil { p, err = b.buildSort(p, sel.OrderBy.Items, nil) if err != nil { @@ -2281,6 +2283,13 @@ func (b *planBuilder) buildUpdate(update *ast.UpdateStmt) (Plan, error) { return nil, errors.Trace(err) } } + // TODO: expression rewriter should not change the output columns. We should cut the columns here. + if p.Schema().Len() != oldSchema.Len() { + proj := LogicalProjection{Exprs: expression.Column2Exprs(oldSchema.Columns)}.init(b.ctx) + proj.SetSchema(oldSchema) + proj.SetChildren(p) + p = proj + } orderedList, np, err := b.buildUpdateLists(tableList, update.List, p) if err != nil { return nil, errors.Trace(err) diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index 810a85db37d0c..018d6cbb26edf 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -955,6 +955,10 @@ func (s *testPlanSuite) TestPlanBuilder(c *C) { // binlog columns, because the schema and data are not consistent. plan: "LeftHashJoin{LeftHashJoin{TableReader(Table(t))->IndexLookUp(Index(t.c_d_e)[[666,666]], Table(t))}(test.t.a,test.t.b)->IndexReader(Index(t.c_d_e)[[42,42]])}(test.t.b,test.t.a)->Sel([or(6_aux_0, 10_aux_0)])->Projection->Delete", }, + { + sql: "update t set a = 2 where b in (select c from t)", + plan: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t.b,test.t.c)->Update", + }, } for _, ca := range tests { comment := Commentf("for %s", ca.sql)