Skip to content

Commit

Permalink
*: rewrite the tableFilter process flow radondb#480
Browse files Browse the repository at this point in the history
[summary]
Refactor the joinnode code about tableFilter by using CloneExpr.
[test case]
src/executor/engine/join_engine_test.go
src/planner/builder/builder_test.go
src/planner/builder/from_test.go
src/planner/select_plan_test.go
[patch codecov]
src/planner/builder 97.8%
  • Loading branch information
zhyass committed Nov 4, 2019
1 parent 13ca41a commit 6dee0e4
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 143 deletions.
6 changes: 3 additions & 3 deletions src/executor/engine/join_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,8 @@ func TestJoinEngine(t *testing.T) {
fakedbs.AddQuery("select B.name, B.id from sbtest.B1 as B order by B.name asc", r2)
fakedbs.AddQuery("select B.name from sbtest.B1 as B where B.id = 1 order by B.name asc", r22)
fakedbs.AddQuery("select B.name, B.id from sbtest.B0 as B where B.id = 0", r2)
fakedbs.AddQuery("select B.name, B.id from sbtest.B0 as B where B.name = 's' and B.id > 2 order by B.id asc", r21)
fakedbs.AddQuery("select B.name, B.id from sbtest.B1 as B where B.name = 's' and B.id > 2 order by B.id asc", r21)
fakedbs.AddQuery("select B.name, B.id from sbtest.B0 as B where B.id > 2 and B.name = 's' order by B.id asc", r21)
fakedbs.AddQuery("select B.name, B.id from sbtest.B1 as B where B.id > 2 and B.name = 's' order by B.id asc", r21)
fakedbs.AddQuery("select B.name, B.id from sbtest.B0 as B where B.id > 2 order by B.id asc", r21)
fakedbs.AddQuery("select B.name, B.id from sbtest.B1 as B where B.id > 2 order by B.id asc", r2)
fakedbs.AddQuery("select /*+nested+*/ B.name, B.id from sbtest.B1 as B where B.id = 1 and 'go' = B.name", r21)
Expand Down Expand Up @@ -483,7 +483,7 @@ func TestMaxRowErr(t *testing.T) {
// desc
fakedbs.AddQuery("select a.id, a.name from sbtest.a8 as a where a.id = 3 order by a.id asc", r1)
fakedbs.AddQuery("select /*+nested+*/ a.id, a.name from sbtest.a8 as a where a.id = 3", r1)
fakedbs.AddQuery("select /*+nested+*/ b.id, b.name from sbtest.b1 as b where 3 = b.id and b.id = 3", r1)
fakedbs.AddQuery("select /*+nested+*/ b.id, b.name from sbtest.b1 as b where b.id = 3 and 3 = b.id", r1)
fakedbs.AddQueryPattern("select b.id, b.name from .*", r2)
fakedbs.AddQueryPattern("select b.name, b.id from .*", r3)
fakedbs.AddQueryPattern("select s.id, s.name from .*", r1)
Expand Down
6 changes: 3 additions & 3 deletions src/planner/builder/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func TestProcessSelect(t *testing.T) {
Range: "[512-4096)",
},
{
Query: "select B.id from sbtest.B1 as B where 1 = 1 and B.b = 2 and B.id = 1 order by B.id asc",
Query: "select B.id from sbtest.B1 as B where B.id = 1 and 1 = 1 and B.b = 2 order by B.id asc",
Backend: "backend2",
Range: "[512-4096)",
}},
Expand All @@ -169,7 +169,7 @@ func TestProcessSelect(t *testing.T) {
Range: "[512-4096)",
},
{
Query: "select /*+nested+*/ 1 from sbtest.B1 as B where :A_id = B.id and B.id = 1",
Query: "select /*+nested+*/ 1 from sbtest.B1 as B where B.id = 1 and :A_id = B.id",
Backend: "backend2",
Range: "[512-4096)",
}},
Expand Down Expand Up @@ -267,7 +267,7 @@ func TestProcessSelect(t *testing.T) {
project: "sum(A.a), b",
out: []xcontext.QueryTuple{
{
Query: "select sum(A.a), S.b from sbtest.A1 as A join sbtest.S on A.id = S.id where A.id = 0 group by S.b",
Query: "select sum(A.a), S.b from sbtest.A1 as A join sbtest.S on A.id = S.id where A.id = 0 and S.id = 0 group by S.b",
Backend: "backend1",
Range: "[0-32)",
}},
Expand Down
7 changes: 2 additions & 5 deletions src/planner/builder/from_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ func TestScanTableExprs(t *testing.T) {
}
assert.Equal(t, 1, len(j.joinOn))
assert.False(t, j.IsLeftJoin)
assert.Equal(t, 1, len(j.tableFilter))

tbMaps := j.getReferTables()
tbInfo := tbMaps["A"]
Expand All @@ -107,7 +106,6 @@ func TestScanTableExprs(t *testing.T) {
}
assert.Equal(t, 1, len(j.joinOn))
assert.True(t, j.IsLeftJoin)
assert.Equal(t, 0, len(j.tableFilter))

tbMaps := j.getReferTables()
tbInfo := tbMaps["A"]
Expand Down Expand Up @@ -139,7 +137,6 @@ func TestScanTableExprs(t *testing.T) {
}
assert.Equal(t, 1, len(j.joinOn))
assert.True(t, j.IsLeftJoin)
assert.Equal(t, 0, len(j.tableFilter))
assert.NotNil(t, j.otherJoinOn)

err = j.pushOtherJoin()
Expand All @@ -160,7 +157,6 @@ func TestScanTableExprs(t *testing.T) {
}
assert.Equal(t, 1, len(j.joinOn))
assert.True(t, j.IsLeftJoin)
assert.Equal(t, 0, len(j.tableFilter))

tbMaps := j.getReferTables()
tbInfo := tbMaps["A"]
Expand Down Expand Up @@ -257,7 +253,6 @@ func TestScanTableExprs(t *testing.T) {
assert.Equal(t, 2, len(j.joinOn))
assert.False(t, j.IsLeftJoin)
assert.Equal(t, 1, len(j.noTableFilter))
assert.Equal(t, 1, len(j.tableFilter))
tbMaps := j.getReferTables()
tbInfo := tbMaps["A"]

Expand Down Expand Up @@ -366,6 +361,7 @@ func TestScanTableExprsError(t *testing.T) {
"select * from G join A on G.id=A.id join B on A.a=G.a",
"select * from G join (A,B) on G.id=A.id and A.a=B.a",
"select * from G join A as G where G.id=1",
"select * from A join B on A.id=B.id and B.id=0x12",
}
wants := []string{
"Table 'C' doesn't exist (errno 1146) (sqlstate 42S02)",
Expand All @@ -381,6 +377,7 @@ func TestScanTableExprsError(t *testing.T) {
"unsupported: join.on.condition.should.cross.left-right.tables",
"unsupported: join.on.condition.should.cross.left-right.tables",
"unsupported: not.unique.table.or.alias:'G'",
"hash.unsupported.key.type:[3]",
}
log := xlog.NewStdLog(xlog.Level(xlog.PANIC))
database := "sbtest"
Expand Down
190 changes: 71 additions & 119 deletions src/planner/builder/join_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,14 @@ type JoinNode struct {
// eg: t1 join t2 on t1.a>t2.a, 't1.a>t2.a' parser into CmpFilter.
CmpFilter []Comparison
/*
* eg: 't1 left join t2 on t1.a=t2.a and t1.b=2' where t1.c=t2.c and 1=1 and t2.b>2 where
* t2.str is null. 't1.b=2' will parser into otherJoinOn, IsLeftJoin is true, 't1.c=t2.c'
* parser into otherFilter, else into joinOn. '1=1' parser into noTableFilter. 't2.b>2' into
* tableFilter. 't2.str is null' into rightNull.
* eg: 't1 left join t2 on t1.a=t2.a and t1.b=2' where t1.c=t2.c and 1=1 and t2.b>2 where t2.str is null.
* 't1.b=2' will parser into otherJoinOn, IsLeftJoin is true, 't1.c=t2.c' parser into otherFilter, else
* into joinOn. '1=1' parser into noTableFilter. 't2.str is null' into rightNull.
*/
tableFilter, otherFilter []exprInfo
noTableFilter []sqlparser.Expr
otherJoinOn *otherJoin
rightNull []selectTuple
otherFilter []exprInfo
noTableFilter []sqlparser.Expr
otherJoinOn *otherJoin
rightNull []selectTuple
// whether is left join.
IsLeftJoin bool
// whether the right node has filters in left join.
Expand All @@ -100,11 +99,6 @@ type JoinNode struct {
LeftTmpCols []int
// record the `rightNull`'s index in right.fields.
RightTmpCols []int
// keyFilters based on LeftKeys、RightKeys and tableFilter.
// eg: select * from t1 join t2 on t1.a=t2.a where t1.a=1
// `t1.a` in LeftKeys, `t1.a=1` in tableFilter. in the map,
// key is 0(index is 0), value is tableFilter(`t1.a=1`).
keyFilters map[int][]exprInfo
// isHint defines whether has /*+nested+*/.
isHint bool
order int
Expand All @@ -127,7 +121,6 @@ func newJoinNode(log *xlog.Log, Left, Right SelectNode, router *router.Router, j
router: router,
joinExpr: joinExpr,
joinOn: joinOn,
keyFilters: make(map[int][]exprInfo),
Vars: make(map[string]int),
referTables: referTables,
IsLeftJoin: isLeftJoin,
Expand Down Expand Up @@ -171,15 +164,8 @@ func (j *JoinNode) pushFilter(filters []exprInfo) error {
if len(filter.cols) != 1 {
tbInfo.parent.setWhereFilter(filter)
} else {
j.tableFilter = append(j.tableFilter, filter)
if len(filter.vals) > 0 && tbInfo.shardKey != "" {
if nameMatch(filter.cols[0], tb, tbInfo.shardKey) {
for _, val := range filter.vals {
if err = getIndex(j.router, tbInfo, val); err != nil {
return err
}
}
}
if err := j.pushKeyFilter(filter, filter.cols[0].Qualifier.Name.String(), filter.cols[0].Name.String()); err != nil {
return err
}
}
} else {
Expand Down Expand Up @@ -208,6 +194,68 @@ func (j *JoinNode) pushFilter(filters []exprInfo) error {
return err
}

// pushKeyFilter used to build the keyFilter based on the tableFilter and joinOn.
// eg: select t1.a,t2.a from t1 join t2 on t1.a=t2.a where t1.a=1;
// push: select t1.a from t1 where t1.a=1 order by t1.a asc;
// select t2.a from t2 where t2.a=1 order by t2.a asc;
func (j *JoinNode) pushKeyFilter(filter exprInfo, table, field string) error {
var tb, col string
var err error
find := false
if _, ok := j.Left.getReferTables()[table]; ok {
for _, join := range j.joinOn {
lt := join.cols[0].Qualifier.Name.String()
lc := join.cols[0].Name.String()
if lt == table && lc == field {
tb = join.cols[1].Qualifier.Name.String()
col = join.cols[1].Name.String()
find = true
break
}
}

if err = j.Left.pushKeyFilter(filter, table, field); err != nil {
return err
}

if find {
// replace the colname.
origin := *(filter.cols[0])
filter.cols[0].Name = sqlparser.NewColIdent(col)
filter.cols[0].Qualifier = sqlparser.TableName{Name: sqlparser.NewTableIdent(tb)}
if err = j.Right.pushKeyFilter(filter, tb, col); err != nil {
return err
}
// recovery the colname in exprisson.
*(filter.cols[0]) = origin
}
} else {
for _, join := range j.joinOn {
rt := join.cols[1].Qualifier.Name.String()
rc := join.cols[1].Name.String()
if rt == table && rc == field {
tb = join.cols[0].Qualifier.Name.String()
col = join.cols[0].Name.String()
find = true
break
}
}
if err = j.Right.pushKeyFilter(filter, table, field); err != nil {
return err
}
if find {
origin := *(filter.cols[0])
filter.cols[0].Name = sqlparser.NewColIdent(col)
filter.cols[0].Qualifier = sqlparser.TableName{Name: sqlparser.NewTableIdent(tb)}
if err = j.Left.pushKeyFilter(filter, tb, col); err != nil {
return err
}
*(filter.cols[0]) = origin
}
}
return nil
}

// setParent set the parent node.
func (j *JoinNode) setParent(p SelectNode) {
j.parent = p
Expand Down Expand Up @@ -344,9 +392,6 @@ func (j *JoinNode) pushEqualCmpr(joins []exprInfo) SelectNode {
mn.setParent(node.parent)
mn.setParenthese(node.hasParen)

for _, filter := range node.tableFilter {
mn.setWhereFilter(filter)
}
for _, filter := range node.otherFilter {
mn.setWhereFilter(filter)
}
Expand Down Expand Up @@ -393,12 +438,6 @@ func (j *JoinNode) pushEqualCmpr(joins []exprInfo) SelectNode {
// calcRoute used to calc the route.
func (j *JoinNode) calcRoute() (SelectNode, error) {
var err error
for _, filter := range j.tableFilter {
if !j.buildKeyFilter(filter, false) {
tbInfo := j.referTables[filter.referTables[0]]
tbInfo.parent.setWhereFilter(filter)
}
}
if j.Left, err = j.Left.calcRoute(); err != nil {
return j, err
}
Expand All @@ -421,11 +460,6 @@ func (j *JoinNode) calcRoute() (SelectNode, error) {
for _, filter := range j.otherFilter {
mn.setWhereFilter(filter)
}
for _, filters := range j.keyFilters {
for _, filter := range filters {
mn.setWhereFilter(filter)
}
}
mn.setNoTableFilter(j.noTableFilter)
if j.joinExpr == nil && len(j.joinOn) > 0 {
mn.pushEqualCmpr(j.joinOn)
Expand All @@ -438,68 +472,6 @@ func (j *JoinNode) calcRoute() (SelectNode, error) {
return j, nil
}

// buildKeyFilter used to build the keyFilter based on the tableFilter and joinOn.
// eg: select t1.a,t2.a from t1 join t2 on t1.a=t2.a where t1.a=1;
// push: select t1.a from t1 where t1.a=1 order by t1.a asc;
// select t2.a from t2 where t2.a=1 order by t2.a asc;
func (j *JoinNode) buildKeyFilter(filter exprInfo, isFind bool) bool {
table := filter.cols[0].Qualifier.Name.String()
field := filter.cols[0].Name.String()
find := false
if _, ok := j.Left.getReferTables()[filter.referTables[0]]; ok {
for i, join := range j.joinOn {
lt := join.cols[0].Qualifier.Name.String()
lc := join.cols[0].Name.String()
if lt == table && lc == field {
j.keyFilters[i] = append(j.keyFilters[i], filter)
if len(filter.vals) > 0 {
rt := join.cols[1].Qualifier.Name.String()
rc := join.cols[1].Name.String()
tbInfo := j.referTables[rt]
if tbInfo.shardKey == rc {
for _, val := range filter.vals {
if err := getIndex(j.router, tbInfo, val); err != nil {
panic(err)
}
}
}
}
find = true
break
}
}
if jn, ok := j.Left.(*JoinNode); ok {
return jn.buildKeyFilter(filter, find || isFind)
}
} else {
for i, join := range j.joinOn {
rt := join.cols[1].Qualifier.Name.String()
rc := join.cols[1].Name.String()
if rt == table && rc == field {
j.keyFilters[i] = append(j.keyFilters[i], filter)
if len(filter.vals) > 0 {
lt := join.cols[0].Qualifier.Name.String()
lc := join.cols[0].Name.String()
tbInfo := j.referTables[lt]
if tbInfo.shardKey == lc {
for _, val := range filter.vals {
if err := getIndex(j.router, tbInfo, val); err != nil {
panic(err)
}
}
}
}
find = true
break
}
}
if jn, ok := j.Right.(*JoinNode); ok {
return jn.buildKeyFilter(filter, find || isFind)
}
}
return find || isFind
}

// pushSelectExprs used to push the select fields.
func (j *JoinNode) pushSelectExprs(fields, groups []selectTuple, sel *sqlparser.Select, aggTyp aggrType) error {
if j.isHint {
Expand Down Expand Up @@ -872,29 +844,9 @@ func (j *JoinNode) buildQuery(tbInfos map[string]*tableInfo) {
}
}

for i, filters := range j.keyFilters {
table := j.RightKeys[i].Table
field := j.RightKeys[i].Field
tbInfo := j.referTables[table]
for _, filter := range filters {
filter.cols[0].Qualifier.Name = sqlparser.NewTableIdent(table)
filter.cols[0].Name = sqlparser.NewColIdent(field)
tbInfo.parent.filters[filter.expr] = 0
}
}
j.Right.setNoTableFilter(j.noTableFilter)
j.Right.buildQuery(tbInfos)

for i, filters := range j.keyFilters {
table := j.LeftKeys[i].Table
field := j.LeftKeys[i].Field
tbInfo := j.referTables[table]
for _, filter := range filters {
filter.cols[0].Qualifier.Name = sqlparser.NewTableIdent(table)
filter.cols[0].Name = sqlparser.NewColIdent(field)
tbInfo.parent.filters[filter.expr] = 0
}
}
j.Left.setNoTableFilter(j.noTableFilter)
j.Left.buildQuery(tbInfos)
}
Expand Down
Loading

0 comments on commit 6dee0e4

Please sign in to comment.