Skip to content

Commit

Permalink
*: union support order by, limit and different all/distinct.
Browse files Browse the repository at this point in the history
  • Loading branch information
siddontang committed Oct 7, 2015
1 parent 4cc7351 commit 27996c7
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 8 deletions.
2 changes: 1 addition & 1 deletion plan/plans/union.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (p *UnionPlan) Explain(w format.Formatter) {

// GetFields implements plan.Plan GetFields interface.
func (p *UnionPlan) GetFields() []*field.ResultField {
return p.Srcs[0].GetFields()
return p.RFields
}

// Filter implements plan.Plan Filter interface.
Expand Down
61 changes: 60 additions & 1 deletion stmt/stmts/union.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
package stmts

import (
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/field"
"github.com/pingcap/tidb/plan"
"github.com/pingcap/tidb/plan/plans"
"github.com/pingcap/tidb/rset"
Expand Down Expand Up @@ -59,14 +61,71 @@ func (s *UnionStmt) SetText(text string) {
// Plan implements the plan.Planner interface.
func (s *UnionStmt) Plan(ctx context.Context) (plan.Plan, error) {
srcs := make([]plan.Plan, 0, len(s.Selects))
columnCount := 0
for _, s := range s.Selects {
p, err := s.Plan(ctx)
if err != nil {
return nil, err
}
if columnCount > 0 && columnCount != len(p.GetFields()) {
return nil, errors.New("The used SELECT statements have a different number of columns")
}
columnCount = len(p.GetFields())

srcs = append(srcs, p)
}
return &plans.UnionPlan{Srcs: srcs, Distincts: s.Distincts}, nil

for i := len(s.Distincts) - 1; i >= 0; i-- {
if s.Distincts[i] {
// distinct overwrites all previous all
// e.g, select * from t1 union all select * from t2 union distinct select * from t3.
// The distinct will overwrite all for t1 and t2.
i--
for ; i >= 0; i-- {
s.Distincts[i] = true
}
break
}
}

fields := srcs[0].GetFields()
selectList := &plans.SelectList{}
selectList.ResultFields = make([]*field.ResultField, len(fields))
selectList.HiddenFieldOffset = len(fields)

// Union uses first select return column names and ignores table name.
// We only care result name and type here.
for i, f := range fields {
nf := &field.ResultField{}
nf.Name = f.Name
nf.FieldType = f.FieldType
selectList.ResultFields[i] = nf
}

var (
r plan.Plan
err error
)

r = &plans.UnionPlan{Srcs: srcs, Distincts: s.Distincts, RFields: selectList.ResultFields}

if s := s.OrderBy; s != nil {
if r, err = (&rsets.OrderByRset{By: s.By,
Src: r,
SelectList: selectList,
}).Plan(ctx); err != nil {
return nil, err
}
}

if s := s.Offset; s != nil {
r = &plans.OffsetDefaultPlan{Count: s.Count, Src: r, Fields: r.GetFields()}
}
if s := s.Limit; s != nil {
r = &plans.LimitDefaultPlan{Count: s.Count, Src: r, Fields: r.GetFields()}
}

return r, nil
}

// Exec implements the stmt.Statement Exec interface.
Expand Down
70 changes: 64 additions & 6 deletions stmt/stmts/union_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
package stmts_test

import (
"database/sql"
"fmt"

. "github.com/pingcap/check"
"github.com/pingcap/tidb"
"github.com/pingcap/tidb/stmt/stmts"
Expand Down Expand Up @@ -45,16 +48,71 @@ func (s *testStmtSuite) TestUnion(c *C) {
tx := mustBegin(c, s.testDB)
rows, err := tx.Query(testSQL)
c.Assert(err, IsNil)
matchRows(c, rows, [][]interface{}{{1}, {2}})

rows, err = tx.Query("select 1 union all select 1")
c.Assert(err, IsNil)
matchRows(c, rows, [][]interface{}{{1}, {1}})

rows, err = tx.Query("select 1 union all select 1 union select 1")
c.Assert(err, IsNil)
matchRows(c, rows, [][]interface{}{{1}})

rows, err = tx.Query("select 1 union (select 2) limit 1")
c.Assert(err, IsNil)
matchRows(c, rows, [][]interface{}{{1}})

rows, err = tx.Query("select 1 union (select 2) limit 1, 1")
c.Assert(err, IsNil)
matchRows(c, rows, [][]interface{}{{2}})

rows, err = tx.Query("select id from union_test union all (select 1) order by id desc")
c.Assert(err, IsNil)
matchRows(c, rows, [][]interface{}{{2}, {1}, {1}})

rows, err = tx.Query("select id as a from union_test union (select 1) order by a desc")
c.Assert(err, IsNil)
matchRows(c, rows, [][]interface{}{{2}, {1}})

mustCommit(c, tx)
}

i := 1
func dumpRows(c *C, rows *sql.Rows) [][]interface{} {
cols, err := rows.Columns()
c.Assert(err, IsNil)
ay := make([][]interface{}, 0)
for rows.Next() {
var id int
rows.Scan(&id)
c.Assert(id, Equals, i)
v := make([]interface{}, len(cols))
for i := range v {
v[i] = new(interface{})
}
err = rows.Scan(v...)
c.Assert(err, IsNil)

i++
for i := range v {
v[i] = *(v[i].(*interface{}))
}
ay = append(ay, v)
}

rows.Close()
mustCommit(c, tx)
c.Assert(rows.Err(), IsNil)
return ay
}

func matchRows(c *C, rows *sql.Rows, expected [][]interface{}) {
ay := dumpRows(c, rows)
c.Assert(len(ay), Equals, len(expected))
for i := range ay {
match(c, ay[i], expected[i]...)
}
}

func match(c *C, row []interface{}, expected ...interface{}) {
c.Assert(len(row), Equals, len(expected))
for i := range row {
got := fmt.Sprintf("%v", row[i])
need := fmt.Sprintf("%v", expected[i])
c.Assert(got, Equals, need)
}
}

1 comment on commit 27996c7

@qiuyesuifeng
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Please sign in to comment.