From 6dee0e4c2e0ad92673aae1ceb86df4fb382a2628 Mon Sep 17 00:00:00 2001 From: zhyass <34016424+zhyass@users.noreply.github.com> Date: Mon, 4 Nov 2019 17:03:53 +0800 Subject: [PATCH] *: rewrite the tableFilter process flow #480 [summary] Refactor the joinnode code about tableFilter by using CloneExpr. [test case] src/executor/engine/join_engine_test.go src/planner/builder/builder_test.go src/planner/builder/from_test.go src/planner/select_plan_test.go [patch codecov] src/planner/builder 97.8% --- src/executor/engine/join_engine_test.go | 6 +- src/planner/builder/builder_test.go | 6 +- src/planner/builder/from_test.go | 7 +- src/planner/builder/join_node.go | 190 +++++++++--------------- src/planner/builder/merge_node.go | 26 ++-- src/planner/builder/plan_node.go | 1 + src/planner/select_plan_test.go | 6 +- 7 files changed, 99 insertions(+), 143 deletions(-) diff --git a/src/executor/engine/join_engine_test.go b/src/executor/engine/join_engine_test.go index 605116c5..60bff9af 100644 --- a/src/executor/engine/join_engine_test.go +++ b/src/executor/engine/join_engine_test.go @@ -274,8 +274,8 @@ func TestJoinEngine(t *testing.T) { fakedbs.AddQuery("select B.name, B.id from sbtest.B1 as B order by B.name asc", r2) fakedbs.AddQuery("select B.name from sbtest.B1 as B where B.id = 1 order by B.name asc", r22) fakedbs.AddQuery("select B.name, B.id from sbtest.B0 as B where B.id = 0", r2) - fakedbs.AddQuery("select B.name, B.id from sbtest.B0 as B where B.name = 's' and B.id > 2 order by B.id asc", r21) - fakedbs.AddQuery("select B.name, B.id from sbtest.B1 as B where B.name = 's' and B.id > 2 order by B.id asc", r21) + fakedbs.AddQuery("select B.name, B.id from sbtest.B0 as B where B.id > 2 and B.name = 's' order by B.id asc", r21) + fakedbs.AddQuery("select B.name, B.id from sbtest.B1 as B where B.id > 2 and B.name = 's' order by B.id asc", r21) fakedbs.AddQuery("select B.name, B.id from sbtest.B0 as B where B.id > 2 order by B.id asc", r21) fakedbs.AddQuery("select B.name, 
B.id from sbtest.B1 as B where B.id > 2 order by B.id asc", r2) fakedbs.AddQuery("select /*+nested+*/ B.name, B.id from sbtest.B1 as B where B.id = 1 and 'go' = B.name", r21) @@ -483,7 +483,7 @@ func TestMaxRowErr(t *testing.T) { // desc fakedbs.AddQuery("select a.id, a.name from sbtest.a8 as a where a.id = 3 order by a.id asc", r1) fakedbs.AddQuery("select /*+nested+*/ a.id, a.name from sbtest.a8 as a where a.id = 3", r1) - fakedbs.AddQuery("select /*+nested+*/ b.id, b.name from sbtest.b1 as b where 3 = b.id and b.id = 3", r1) + fakedbs.AddQuery("select /*+nested+*/ b.id, b.name from sbtest.b1 as b where b.id = 3 and 3 = b.id", r1) fakedbs.AddQueryPattern("select b.id, b.name from .*", r2) fakedbs.AddQueryPattern("select b.name, b.id from .*", r3) fakedbs.AddQueryPattern("select s.id, s.name from .*", r1) diff --git a/src/planner/builder/builder_test.go b/src/planner/builder/builder_test.go index 683eb325..b158ec6b 100644 --- a/src/planner/builder/builder_test.go +++ b/src/planner/builder/builder_test.go @@ -153,7 +153,7 @@ func TestProcessSelect(t *testing.T) { Range: "[512-4096)", }, { - Query: "select B.id from sbtest.B1 as B where 1 = 1 and B.b = 2 and B.id = 1 order by B.id asc", + Query: "select B.id from sbtest.B1 as B where B.id = 1 and 1 = 1 and B.b = 2 order by B.id asc", Backend: "backend2", Range: "[512-4096)", }}, @@ -169,7 +169,7 @@ func TestProcessSelect(t *testing.T) { Range: "[512-4096)", }, { - Query: "select /*+nested+*/ 1 from sbtest.B1 as B where :A_id = B.id and B.id = 1", + Query: "select /*+nested+*/ 1 from sbtest.B1 as B where B.id = 1 and :A_id = B.id", Backend: "backend2", Range: "[512-4096)", }}, @@ -267,7 +267,7 @@ func TestProcessSelect(t *testing.T) { project: "sum(A.a), b", out: []xcontext.QueryTuple{ { - Query: "select sum(A.a), S.b from sbtest.A1 as A join sbtest.S on A.id = S.id where A.id = 0 group by S.b", + Query: "select sum(A.a), S.b from sbtest.A1 as A join sbtest.S on A.id = S.id where A.id = 0 and S.id = 0 group by S.b", 
Backend: "backend1", Range: "[0-32)", }}, diff --git a/src/planner/builder/from_test.go b/src/planner/builder/from_test.go index 506fe128..b6414789 100644 --- a/src/planner/builder/from_test.go +++ b/src/planner/builder/from_test.go @@ -80,7 +80,6 @@ func TestScanTableExprs(t *testing.T) { } assert.Equal(t, 1, len(j.joinOn)) assert.False(t, j.IsLeftJoin) - assert.Equal(t, 1, len(j.tableFilter)) tbMaps := j.getReferTables() tbInfo := tbMaps["A"] @@ -107,7 +106,6 @@ func TestScanTableExprs(t *testing.T) { } assert.Equal(t, 1, len(j.joinOn)) assert.True(t, j.IsLeftJoin) - assert.Equal(t, 0, len(j.tableFilter)) tbMaps := j.getReferTables() tbInfo := tbMaps["A"] @@ -139,7 +137,6 @@ func TestScanTableExprs(t *testing.T) { } assert.Equal(t, 1, len(j.joinOn)) assert.True(t, j.IsLeftJoin) - assert.Equal(t, 0, len(j.tableFilter)) assert.NotNil(t, j.otherJoinOn) err = j.pushOtherJoin() @@ -160,7 +157,6 @@ func TestScanTableExprs(t *testing.T) { } assert.Equal(t, 1, len(j.joinOn)) assert.True(t, j.IsLeftJoin) - assert.Equal(t, 0, len(j.tableFilter)) tbMaps := j.getReferTables() tbInfo := tbMaps["A"] @@ -257,7 +253,6 @@ func TestScanTableExprs(t *testing.T) { assert.Equal(t, 2, len(j.joinOn)) assert.False(t, j.IsLeftJoin) assert.Equal(t, 1, len(j.noTableFilter)) - assert.Equal(t, 1, len(j.tableFilter)) tbMaps := j.getReferTables() tbInfo := tbMaps["A"] @@ -366,6 +361,7 @@ func TestScanTableExprsError(t *testing.T) { "select * from G join A on G.id=A.id join B on A.a=G.a", "select * from G join (A,B) on G.id=A.id and A.a=B.a", "select * from G join A as G where G.id=1", + "select * from A join B on A.id=B.id and B.id=0x12", } wants := []string{ "Table 'C' doesn't exist (errno 1146) (sqlstate 42S02)", @@ -381,6 +377,7 @@ func TestScanTableExprsError(t *testing.T) { "unsupported: join.on.condition.should.cross.left-right.tables", "unsupported: join.on.condition.should.cross.left-right.tables", "unsupported: not.unique.table.or.alias:'G'", + "hash.unsupported.key.type:[3]", } log 
:= xlog.NewStdLog(xlog.Level(xlog.PANIC)) database := "sbtest" diff --git a/src/planner/builder/join_node.go b/src/planner/builder/join_node.go index 5156f49c..4881cf0f 100644 --- a/src/planner/builder/join_node.go +++ b/src/planner/builder/join_node.go @@ -83,15 +83,14 @@ type JoinNode struct { // eg: t1 join t2 on t1.a>t2.a, 't1.a>t2.a' parser into CmpFilter. CmpFilter []Comparison /* - * eg: 't1 left join t2 on t1.a=t2.a and t1.b=2' where t1.c=t2.c and 1=1 and t2.b>2 where - * t2.str is null. 't1.b=2' will parser into otherJoinOn, IsLeftJoin is true, 't1.c=t2.c' - * parser into otherFilter, else into joinOn. '1=1' parser into noTableFilter. 't2.b>2' into - * tableFilter. 't2.str is null' into rightNull. + * eg: 't1 left join t2 on t1.a=t2.a and t1.b=2' where t1.c=t2.c and 1=1 and t2.b>2 where t2.str is null. + * 't1.b=2' will parser into otherJoinOn, IsLeftJoin is true, 't1.c=t2.c' parser into otherFilter, else + * into joinOn. '1=1' parser into noTableFilter. 't2.str is null' into rightNull. */ - tableFilter, otherFilter []exprInfo - noTableFilter []sqlparser.Expr - otherJoinOn *otherJoin - rightNull []selectTuple + otherFilter []exprInfo + noTableFilter []sqlparser.Expr + otherJoinOn *otherJoin + rightNull []selectTuple // whether is left join. IsLeftJoin bool // whether the right node has filters in left join. @@ -100,11 +99,6 @@ type JoinNode struct { LeftTmpCols []int // record the `rightNull`'s index in right.fields. RightTmpCols []int - // keyFilters based on LeftKeys、RightKeys and tableFilter. - // eg: select * from t1 join t2 on t1.a=t2.a where t1.a=1 - // `t1.a` in LeftKeys, `t1.a=1` in tableFilter. in the map, - // key is 0(index is 0), value is tableFilter(`t1.a=1`). - keyFilters map[int][]exprInfo // isHint defines whether has /*+nested+*/. 
isHint bool order int @@ -127,7 +121,6 @@ func newJoinNode(log *xlog.Log, Left, Right SelectNode, router *router.Router, j router: router, joinExpr: joinExpr, joinOn: joinOn, - keyFilters: make(map[int][]exprInfo), Vars: make(map[string]int), referTables: referTables, IsLeftJoin: isLeftJoin, @@ -171,15 +164,8 @@ func (j *JoinNode) pushFilter(filters []exprInfo) error { if len(filter.cols) != 1 { tbInfo.parent.setWhereFilter(filter) } else { - j.tableFilter = append(j.tableFilter, filter) - if len(filter.vals) > 0 && tbInfo.shardKey != "" { - if nameMatch(filter.cols[0], tb, tbInfo.shardKey) { - for _, val := range filter.vals { - if err = getIndex(j.router, tbInfo, val); err != nil { - return err - } - } - } + if err := j.pushKeyFilter(filter, filter.cols[0].Qualifier.Name.String(), filter.cols[0].Name.String()); err != nil { + return err } } } else { @@ -208,6 +194,68 @@ func (j *JoinNode) pushFilter(filters []exprInfo) error { return err } +// pushKeyFilter pushes a single-column filter down to the child that owns the table; if the column is a joinOn key, the filter is also pushed to the other child with the column replaced by the matching join column. +// eg: select t1.a,t2.a from t1 join t2 on t1.a=t2.a where t1.a=1; +// push: select t1.a from t1 where t1.a=1 order by t1.a asc; +// select t2.a from t2 where t2.a=1 order by t2.a asc; +func (j *JoinNode) pushKeyFilter(filter exprInfo, table, field string) error { + var tb, col string + var err error + find := false + if _, ok := j.Left.getReferTables()[table]; ok { + for _, join := range j.joinOn { + lt := join.cols[0].Qualifier.Name.String() + lc := join.cols[0].Name.String() + if lt == table && lc == field { + tb = join.cols[1].Qualifier.Name.String() + col = join.cols[1].Name.String() + find = true + break + } + } + + if err = j.Left.pushKeyFilter(filter, table, field); err != nil { + return err + } + + if find { + // replace the colname.
+ origin := *(filter.cols[0]) + filter.cols[0].Name = sqlparser.NewColIdent(col) + filter.cols[0].Qualifier = sqlparser.TableName{Name: sqlparser.NewTableIdent(tb)} + if err = j.Right.pushKeyFilter(filter, tb, col); err != nil { + return err + } + // recover the colname in the expression. + *(filter.cols[0]) = origin + } + } else { + for _, join := range j.joinOn { + rt := join.cols[1].Qualifier.Name.String() + rc := join.cols[1].Name.String() + if rt == table && rc == field { + tb = join.cols[0].Qualifier.Name.String() + col = join.cols[0].Name.String() + find = true + break + } + } + if err = j.Right.pushKeyFilter(filter, table, field); err != nil { + return err + } + if find { + origin := *(filter.cols[0]) + filter.cols[0].Name = sqlparser.NewColIdent(col) + filter.cols[0].Qualifier = sqlparser.TableName{Name: sqlparser.NewTableIdent(tb)} + if err = j.Left.pushKeyFilter(filter, tb, col); err != nil { + return err + } + *(filter.cols[0]) = origin + } + } + return nil +} + // setParent set the parent node. func (j *JoinNode) setParent(p SelectNode) { j.parent = p @@ -344,9 +392,6 @@ func (j *JoinNode) pushEqualCmpr(joins []exprInfo) SelectNode { mn.setParent(node.parent) mn.setParenthese(node.hasParen) - for _, filter := range node.tableFilter { - mn.setWhereFilter(filter) - } for _, filter := range node.otherFilter { mn.setWhereFilter(filter) } @@ -393,12 +438,6 @@ func (j *JoinNode) pushEqualCmpr(joins []exprInfo) SelectNode { // calcRoute used to calc the route.
func (j *JoinNode) calcRoute() (SelectNode, error) { var err error - for _, filter := range j.tableFilter { - if !j.buildKeyFilter(filter, false) { - tbInfo := j.referTables[filter.referTables[0]] - tbInfo.parent.setWhereFilter(filter) - } - } if j.Left, err = j.Left.calcRoute(); err != nil { return j, err } @@ -421,11 +460,6 @@ func (j *JoinNode) calcRoute() (SelectNode, error) { for _, filter := range j.otherFilter { mn.setWhereFilter(filter) } - for _, filters := range j.keyFilters { - for _, filter := range filters { - mn.setWhereFilter(filter) - } - } mn.setNoTableFilter(j.noTableFilter) if j.joinExpr == nil && len(j.joinOn) > 0 { mn.pushEqualCmpr(j.joinOn) @@ -438,68 +472,6 @@ func (j *JoinNode) calcRoute() (SelectNode, error) { return j, nil } -// buildKeyFilter used to build the keyFilter based on the tableFilter and joinOn. -// eg: select t1.a,t2.a from t1 join t2 on t1.a=t2.a where t1.a=1; -// push: select t1.a from t1 where t1.a=1 order by t1.a asc; -// select t2.a from t2 where t2.a=1 order by t2.a asc; -func (j *JoinNode) buildKeyFilter(filter exprInfo, isFind bool) bool { - table := filter.cols[0].Qualifier.Name.String() - field := filter.cols[0].Name.String() - find := false - if _, ok := j.Left.getReferTables()[filter.referTables[0]]; ok { - for i, join := range j.joinOn { - lt := join.cols[0].Qualifier.Name.String() - lc := join.cols[0].Name.String() - if lt == table && lc == field { - j.keyFilters[i] = append(j.keyFilters[i], filter) - if len(filter.vals) > 0 { - rt := join.cols[1].Qualifier.Name.String() - rc := join.cols[1].Name.String() - tbInfo := j.referTables[rt] - if tbInfo.shardKey == rc { - for _, val := range filter.vals { - if err := getIndex(j.router, tbInfo, val); err != nil { - panic(err) - } - } - } - } - find = true - break - } - } - if jn, ok := j.Left.(*JoinNode); ok { - return jn.buildKeyFilter(filter, find || isFind) - } - } else { - for i, join := range j.joinOn { - rt := join.cols[1].Qualifier.Name.String() - rc := 
join.cols[1].Name.String() - if rt == table && rc == field { - j.keyFilters[i] = append(j.keyFilters[i], filter) - if len(filter.vals) > 0 { - lt := join.cols[0].Qualifier.Name.String() - lc := join.cols[0].Name.String() - tbInfo := j.referTables[lt] - if tbInfo.shardKey == lc { - for _, val := range filter.vals { - if err := getIndex(j.router, tbInfo, val); err != nil { - panic(err) - } - } - } - } - find = true - break - } - } - if jn, ok := j.Right.(*JoinNode); ok { - return jn.buildKeyFilter(filter, find || isFind) - } - } - return find || isFind -} - // pushSelectExprs used to push the select fields. func (j *JoinNode) pushSelectExprs(fields, groups []selectTuple, sel *sqlparser.Select, aggTyp aggrType) error { if j.isHint { @@ -872,29 +844,9 @@ func (j *JoinNode) buildQuery(tbInfos map[string]*tableInfo) { } } - for i, filters := range j.keyFilters { - table := j.RightKeys[i].Table - field := j.RightKeys[i].Field - tbInfo := j.referTables[table] - for _, filter := range filters { - filter.cols[0].Qualifier.Name = sqlparser.NewTableIdent(table) - filter.cols[0].Name = sqlparser.NewColIdent(field) - tbInfo.parent.filters[filter.expr] = 0 - } - } j.Right.setNoTableFilter(j.noTableFilter) j.Right.buildQuery(tbInfos) - for i, filters := range j.keyFilters { - table := j.LeftKeys[i].Table - field := j.LeftKeys[i].Field - tbInfo := j.referTables[table] - for _, filter := range filters { - filter.cols[0].Qualifier.Name = sqlparser.NewTableIdent(table) - filter.cols[0].Name = sqlparser.NewColIdent(field) - tbInfo.parent.filters[filter.expr] = 0 - } - } j.Left.setNoTableFilter(j.noTableFilter) j.Left.buildQuery(tbInfos) } diff --git a/src/planner/builder/merge_node.go b/src/planner/builder/merge_node.go index 5f41fdf4..4d4d4732 100644 --- a/src/planner/builder/merge_node.go +++ b/src/planner/builder/merge_node.go @@ -49,11 +49,7 @@ type MergeNode struct { ParsedQuerys []*sqlparser.ParsedQuery // the returned result fields, used in the Multiple Plan Tree. 
fields []selectTuple - // filters record the filter, map struct for remove duplicate. - // eg: from t1 join t2 on t1.a=t2.a join t3 on t3.a=t2.a and t2.a=1. - // need avoid the duplicate filter `t2.a=1`. - filters map[sqlparser.Expr]int - order int + order int // Mode. ReqMode xcontext.RequestMode // aliasIndex is the tmp col's alias index. @@ -66,7 +62,6 @@ func newMergeNode(log *xlog.Log, router *router.Router) *MergeNode { log: log, router: router, referTables: make(map[string]*tableInfo), - filters: make(map[sqlparser.Expr]int), ReqMode: xcontext.ReqNormal, } } @@ -117,6 +112,21 @@ func (m *MergeNode) pushFilter(filters []exprInfo) error { return err } +func (m *MergeNode) pushKeyFilter(filter exprInfo, table, field string) error { + expr := sqlparser.CloneExpr(filter.expr) + m.addWhere(expr) + + tbInfo := m.referTables[table] + if field == tbInfo.shardKey && len(filter.vals) > 0 { + for _, val := range filter.vals { + if err := getIndex(m.router, tbInfo, val); err != nil { + return err + } + } + } + return nil +} + // setParent set the parent node. 
func (m *MergeNode) setParent(p SelectNode) { m.parent = p @@ -300,10 +310,6 @@ func (m *MergeNode) Order() int { func (m *MergeNode) buildQuery(tbInfos map[string]*tableInfo) { var Range string if sel, ok := m.Sel.(*sqlparser.Select); ok { - for expr := range m.filters { - m.addWhere(expr) - } - if len(sel.SelectExprs) == 0 { sel.SelectExprs = append(sel.SelectExprs, &sqlparser.AliasedExpr{ Expr: sqlparser.NewIntVal([]byte("1"))}) diff --git a/src/planner/builder/plan_node.go b/src/planner/builder/plan_node.go index d87b9d6c..eeeea1ec 100644 --- a/src/planner/builder/plan_node.go +++ b/src/planner/builder/plan_node.go @@ -29,6 +29,7 @@ type PlanNode interface { type SelectNode interface { PlanNode pushFilter(filters []exprInfo) error + pushKeyFilter(filter exprInfo, table, field string) error setParent(p SelectNode) setWhereFilter(filter exprInfo) setNoTableFilter(exprs []sqlparser.Expr) diff --git a/src/planner/select_plan_test.go b/src/planner/select_plan_test.go index e7f545c9..c7725345 100644 --- a/src/planner/select_plan_test.go +++ b/src/planner/select_plan_test.go @@ -178,7 +178,7 @@ func TestSelectPlan(t *testing.T) { "Range": "[512-4096)" }, { - "Query": "select B.id from sbtest.B1 as B where 1 = 1 and B.b = 2 and B.id = 1 order by B.id asc", + "Query": "select B.id from sbtest.B1 as B where B.id = 1 and 1 = 1 and B.b = 2 order by B.id asc", "Backend": "backend2", "Range": "[512-4096)" } @@ -198,7 +198,7 @@ func TestSelectPlan(t *testing.T) { "Range": "[512-4096)" }, { - "Query": "select /*+nested+*/ 1 from sbtest.B1 as B where :A_id = B.id and B.id = 1", + "Query": "select /*+nested+*/ 1 from sbtest.B1 as B where B.id = 1 and :A_id = B.id", "Backend": "backend2", "Range": "[512-4096)" } @@ -329,7 +329,7 @@ func TestSelectPlan(t *testing.T) { "Project": "sum(A.a), b", "Partitions": [ { - "Query": "select sum(A.a), S.b from sbtest.A1 as A join sbtest.S on A.id = S.id where A.id = 0 group by S.b", + "Query": "select sum(A.a), S.b from sbtest.A1 as A join 
sbtest.S on A.id = S.id where A.id = 0 and S.id = 0 group by S.b", "Backend": "backend1", "Range": "[0-32)" }