diff --git a/plan/plans/union.go b/plan/plans/union.go index a19677fa6e28d..c9069abb35f3b 100644 --- a/plan/plans/union.go +++ b/plan/plans/union.go @@ -44,7 +44,7 @@ func (p *UnionPlan) Explain(w format.Formatter) { // GetFields implements plan.Plan GetFields interface. func (p *UnionPlan) GetFields() []*field.ResultField { - return p.Srcs[0].GetFields() + return p.RFields } // Filter implements plan.Plan Filter interface. diff --git a/stmt/stmts/union.go b/stmt/stmts/union.go index e61c7dbc398ed..90390a7187784 100644 --- a/stmt/stmts/union.go +++ b/stmt/stmts/union.go @@ -14,7 +14,9 @@ package stmts import ( + "github.com/juju/errors" "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/plan/plans" "github.com/pingcap/tidb/rset" @@ -59,14 +61,71 @@ func (s *UnionStmt) SetText(text string) { // Plan implements the plan.Planner interface. func (s *UnionStmt) Plan(ctx context.Context) (plan.Plan, error) { srcs := make([]plan.Plan, 0, len(s.Selects)) + columnCount := 0 for _, s := range s.Selects { p, err := s.Plan(ctx) if err != nil { return nil, err } + if columnCount > 0 && columnCount != len(p.GetFields()) { + return nil, errors.New("The used SELECT statements have a different number of columns") + } + columnCount = len(p.GetFields()) + srcs = append(srcs, p) } - return &plans.UnionPlan{Srcs: srcs, Distincts: s.Distincts}, nil + + for i := len(s.Distincts) - 1; i >= 0; i-- { + if s.Distincts[i] { + // distinct overwrites all previous all + // e.g, select * from t1 union all select * from t2 union distinct select * from t3. + // The distinct will overwrite all for t1 and t2. + i-- + for ; i >= 0; i-- { + s.Distincts[i] = true + } + break + } + } + + fields := srcs[0].GetFields() + selectList := &plans.SelectList{} + selectList.ResultFields = make([]*field.ResultField, len(fields)) + selectList.HiddenFieldOffset = len(fields) + + // Union uses first select return column names and ignores table name. + // We only care result name and type here. + for i, f := range fields { + nf := &field.ResultField{} + nf.Name = f.Name + nf.FieldType = f.FieldType + selectList.ResultFields[i] = nf + } + + var ( + r plan.Plan + err error + ) + + r = &plans.UnionPlan{Srcs: srcs, Distincts: s.Distincts, RFields: selectList.ResultFields} + + if s := s.OrderBy; s != nil { + if r, err = (&rsets.OrderByRset{By: s.By, + Src: r, + SelectList: selectList, + }).Plan(ctx); err != nil { + return nil, err + } + } + + if s := s.Offset; s != nil { + r = &plans.OffsetDefaultPlan{Count: s.Count, Src: r, Fields: r.GetFields()} + } + if s := s.Limit; s != nil { + r = &plans.LimitDefaultPlan{Count: s.Count, Src: r, Fields: r.GetFields()} + } + + return r, nil } // Exec implements the stmt.Statement Exec interface. diff --git a/stmt/stmts/union_test.go b/stmt/stmts/union_test.go index 9bbd066c50eec..028f04929f8bf 100644 --- a/stmt/stmts/union_test.go +++ b/stmt/stmts/union_test.go @@ -14,6 +14,9 @@ package stmts_test import ( + "database/sql" + "fmt" + . "github.com/pingcap/check" "github.com/pingcap/tidb" "github.com/pingcap/tidb/stmt/stmts" @@ -45,16 +48,71 @@ func (s *testStmtSuite) TestUnion(c *C) { tx := mustBegin(c, s.testDB) rows, err := tx.Query(testSQL) c.Assert(err, IsNil) + matchRows(c, rows, [][]interface{}{{1}, {2}}) + + rows, err = tx.Query("select 1 union all select 1") + c.Assert(err, IsNil) + matchRows(c, rows, [][]interface{}{{1}, {1}}) + + rows, err = tx.Query("select 1 union all select 1 union select 1") + c.Assert(err, IsNil) + matchRows(c, rows, [][]interface{}{{1}}) + + rows, err = tx.Query("select 1 union (select 2) limit 1") + c.Assert(err, IsNil) + matchRows(c, rows, [][]interface{}{{1}}) + + rows, err = tx.Query("select 1 union (select 2) limit 1, 1") + c.Assert(err, IsNil) + matchRows(c, rows, [][]interface{}{{2}}) + + rows, err = tx.Query("select id from union_test union all (select 1) order by id desc") + c.Assert(err, IsNil) + matchRows(c, rows, [][]interface{}{{2}, {1}, {1}}) + + rows, err = tx.Query("select id as a from union_test union (select 1) order by a desc") + c.Assert(err, IsNil) + matchRows(c, rows, [][]interface{}{{2}, {1}}) + + mustCommit(c, tx) +} - i := 1 +func dumpRows(c *C, rows *sql.Rows) [][]interface{} { + cols, err := rows.Columns() + c.Assert(err, IsNil) + ay := make([][]interface{}, 0) for rows.Next() { - var id int - rows.Scan(&id) - c.Assert(id, Equals, i) + v := make([]interface{}, len(cols)) + for i := range v { + v[i] = new(interface{}) + } + err = rows.Scan(v...) + c.Assert(err, IsNil) - i++ + for i := range v { + v[i] = *(v[i].(*interface{})) + } + ay = append(ay, v) } rows.Close() - mustCommit(c, tx) + c.Assert(rows.Err(), IsNil) + return ay +} + +func matchRows(c *C, rows *sql.Rows, expected [][]interface{}) { + ay := dumpRows(c, rows) + c.Assert(len(ay), Equals, len(expected)) + for i := range ay { + match(c, ay[i], expected[i]...) + } +} + +func match(c *C, row []interface{}, expected ...interface{}) { + c.Assert(len(row), Equals, len(expected)) + for i := range row { + got := fmt.Sprintf("%v", row[i]) + need := fmt.Sprintf("%v", expected[i]) + c.Assert(got, Equals, need) + } }