From 09715e377ccf2756279f98d42c04fbc8e105ba9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Taylor?= Date: Thu, 16 Nov 2023 10:43:54 +0100 Subject: [PATCH] Use hash joins when nested loop joins are not feasible (#14448) --- go/sqltypes/type.go | 4 + .../queries/aggregation/distinct_test.go | 8 +- .../vtgate/queries/derived/derived_test.go | 35 +- .../vtgate/queries/derived/schema.sql | 4 +- .../vtgate/queries/orderby/orderby_test.go | 4 +- .../vtgate/queries/random/simplifier_test.go | 32 +- .../queries/reference/reference_test.go | 2 +- .../vtgate/queries/subquery/subquery_test.go | 2 +- go/vt/vtgate/engine/fake_vcursor_test.go | 30 +- go/vt/vtgate/engine/hash_join.go | 266 +++++---- go/vt/vtgate/engine/hash_join_test.go | 244 ++++---- go/vt/vtgate/engine/insert_test.go | 16 +- go/vt/vtgate/engine/join.go | 2 +- go/vt/vtgate/engine/join_test.go | 12 +- go/vt/vtgate/engine/route_test.go | 112 ++-- go/vt/vtgate/engine/semi_join_test.go | 2 +- go/vt/vtgate/engine/simple_projection_test.go | 6 +- .../engine/uncorrelated_subquery_test.go | 4 +- go/vt/vtgate/evalengine/api_coerce.go | 68 +++ go/vt/vtgate/evalengine/api_hash_test.go | 10 +- go/vt/vtgate/evalengine/collation.go | 4 +- go/vt/vtgate/evalengine/compiler.go | 2 +- go/vt/vtgate/evalengine/eval.go | 2 +- go/vt/vtgate/evalengine/expr_compare.go | 2 +- go/vt/vtgate/planbuilder/join.go | 13 + .../planbuilder/operator_transformers.go | 57 ++ .../operators/aggregation_pushing.go | 20 +- .../planbuilder/operators/aggregator.go | 7 +- .../planbuilder/operators/apply_join.go | 14 +- .../vtgate/planbuilder/operators/ast_to_op.go | 2 +- .../vtgate/planbuilder/operators/distinct.go | 3 +- .../planbuilder/operators/expressions.go | 4 +- go/vt/vtgate/planbuilder/operators/filter.go | 7 +- .../vtgate/planbuilder/operators/hash_join.go | 327 +++++++++++ .../planbuilder/operators/hash_join_test.go | 100 ++++ .../operators/horizon_expanding.go | 6 +- go/vt/vtgate/planbuilder/operators/join.go | 5 +- go/vt/vtgate/planbuilder/operators/joins.go | 7 +- .../planbuilder/operators/offset_planning.go | 11 +- .../vtgate/planbuilder/operators/ordering.go | 3 +- go/vt/vtgate/planbuilder/operators/phases.go | 21 +- .../planbuilder/operators/projection.go | 13 +- .../planbuilder/operators/query_planning.go | 48 +- .../operators/rewrite/rewriters.go | 20 +- go/vt/vtgate/planbuilder/operators/route.go | 7 +- .../planbuilder/operators/route_planning.go | 22 +- .../vtgate/planbuilder/operators/subquery.go | 5 +- .../operators/subquery_planning.go | 18 +- .../planbuilder/operators/union_merging.go | 2 +- .../planbuilder/operators/utils_test.go | 90 +++ .../planbuilder/testdata/from_cases.json | 125 +++++ .../planbuilder/testdata/hash_joins.txt | 531 ------------------ .../testdata/unsupported_cases.json | 9 +- go/vt/vtgate/semantics/semantic_state.go | 10 +- go/vt/vttablet/tabletserver/rules/rules.go | 15 +- 55 files changed, 1391 insertions(+), 1004 deletions(-) create mode 100644 go/vt/vtgate/planbuilder/operators/hash_join.go create mode 100644 go/vt/vtgate/planbuilder/operators/hash_join_test.go create mode 100644 go/vt/vtgate/planbuilder/operators/utils_test.go delete mode 100644 go/vt/vtgate/planbuilder/testdata/hash_joins.txt diff --git a/go/sqltypes/type.go b/go/sqltypes/type.go index 9157db685e9..d3436ed8718 100644 --- a/go/sqltypes/type.go +++ b/go/sqltypes/type.go @@ -89,6 +89,10 @@ func IsText(t querypb.Type) bool { return int(t)&flagIsText == flagIsText } +func IsTextOrBinary(t querypb.Type) bool { + return int(t)&flagIsText == flagIsText || int(t)&flagIsBinary == flagIsBinary +} + // IsBinary returns true if querypb.Type is a binary. // If you have a Value object, use its member function. func IsBinary(t querypb.Type) bool { diff --git a/go/test/endtoend/vtgate/queries/aggregation/distinct_test.go b/go/test/endtoend/vtgate/queries/aggregation/distinct_test.go index 0a06190923c..3ec27dae6a6 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/distinct_test.go +++ b/go/test/endtoend/vtgate/queries/aggregation/distinct_test.go @@ -46,9 +46,9 @@ func TestDistinctIt(t *testing.T) { mcmp.AssertMatchesNoOrder("select distinct id from aggr_test", `[[INT64(1)] [INT64(2)] [INT64(3)] [INT64(5)] [INT64(4)] [INT64(6)] [INT64(7)] [INT64(8)]]`) if utils.BinaryIsAtLeastAtVersion(17, "vtgate") { - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ distinct val1 from aggr_test order by val1 desc", `[[VARCHAR("e")] [VARCHAR("d")] [VARCHAR("c")] [VARCHAR("b")] [VARCHAR("a")]]`) - mcmp.AssertMatchesNoOrder("select /*vt+ PLANNER=Gen4 */ distinct val1, count(*) from aggr_test group by val1", `[[VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)] [VARCHAR("d") INT64(1)] [VARCHAR("e") INT64(2)]]`) - mcmp.AssertMatchesNoOrder("select /*vt+ PLANNER=Gen4 */ distinct val1+val2 from aggr_test", `[[NULL] [FLOAT64(1)] [FLOAT64(3)] [FLOAT64(4)]]`) - mcmp.AssertMatchesNoOrder("select /*vt+ PLANNER=Gen4 */ distinct count(*) from aggr_test group by val1", `[[INT64(2)] [INT64(1)]]`) + mcmp.AssertMatches("select distinct val1 from aggr_test order by val1 desc", `[[VARCHAR("e")] [VARCHAR("d")] [VARCHAR("c")] [VARCHAR("b")] [VARCHAR("a")]]`) + mcmp.AssertMatchesNoOrder("select distinct val1, count(*) from aggr_test group by val1", `[[VARCHAR("a") INT64(2)] [VARCHAR("b") INT64(1)] [VARCHAR("c") INT64(2)] [VARCHAR("d") INT64(1)] [VARCHAR("e") INT64(2)]]`) + mcmp.AssertMatchesNoOrder("select distinct val1+val2 from aggr_test", `[[NULL] [FLOAT64(1)] [FLOAT64(3)] [FLOAT64(4)]]`) + mcmp.AssertMatchesNoOrder("select distinct count(*) from aggr_test group by val1", `[[INT64(2)] [INT64(1)]]`) } } diff --git a/go/test/endtoend/vtgate/queries/derived/derived_test.go b/go/test/endtoend/vtgate/queries/derived/derived_test.go index c3360ee4135..293dddb355c 100644 --- a/go/test/endtoend/vtgate/queries/derived/derived_test.go +++ b/go/test/endtoend/vtgate/queries/derived/derived_test.go @@ -52,7 +52,7 @@ func TestDerivedTableWithOrderByLimit(t *testing.T) { mcmp, closer := start(t) defer closer() - mcmp.Exec("select /*vt+ PLANNER=Gen4 */ music.id from music join (select id,name from user order by id limit 2) as d on music.user_id = d.id") + mcmp.Exec("select music.id from music join (select id,name from user order by id limit 2) as d on music.user_id = d.id") } func TestDerivedAggregationOnRHS(t *testing.T) { @@ -60,14 +60,14 @@ func TestDerivedAggregationOnRHS(t *testing.T) { defer closer() mcmp.Exec("set sql_mode = ''") - mcmp.Exec("select /*vt+ PLANNER=Gen4 */ d.a from music join (select id, count(*) as a from user) as d on music.user_id = d.id group by 1") + mcmp.Exec("select d.a from music join (select id, count(*) as a from user) as d on music.user_id = d.id group by 1") } func TestDerivedRemoveInnerOrderBy(t *testing.T) { mcmp, closer := start(t) defer closer() - mcmp.Exec("select /*vt+ PLANNER=Gen4 */ count(*) from (select user.id as oui, music.id as non from user join music on user.id = music.user_id order by user.name) as toto") + mcmp.Exec("select count(*) from (select user.id as oui, music.id as non from user join music on user.id = music.user_id order by user.name) as toto") } func TestDerivedTableWithHaving(t *testing.T) { @@ -76,7 +76,7 @@ func TestDerivedTableWithHaving(t *testing.T) { mcmp.Exec("set sql_mode = ''") // For the given query, we can get any id back, because we aren't grouping by it. - mcmp.AssertMatchesAnyNoCompare("select /*vt+ PLANNER=Gen4 */ * from (select id from user having count(*) >= 1) s", + mcmp.AssertMatchesAnyNoCompare("select * from (select id from user having count(*) >= 1) s", "[[INT64(1)]]", "[[INT64(2)]]", "[[INT64(3)]]", "[[INT64(4)]]", "[[INT64(5)]]") } @@ -84,6 +84,31 @@ func TestDerivedTableColumns(t *testing.T) { mcmp, closer := start(t) defer closer() - mcmp.AssertMatches(`SELECT /*vt+ PLANNER=gen4 */ t.id FROM (SELECT id FROM user) AS t(id) ORDER BY t.id DESC`, + mcmp.AssertMatches(`SELECT t.id FROM (SELECT id FROM user) AS t(id) ORDER BY t.id DESC`, `[[INT64(5)] [INT64(4)] [INT64(3)] [INT64(2)] [INT64(1)]]`) } + +// TestDerivedTablesWithLimit tests queries where we have to limit the right hand side of the join. +// We do this by not using the apply join we usually use, and instead use the hash join engine primitive +// These tests exercise these situations +func TestDerivedTablesWithLimit(t *testing.T) { + // We need full type info before planning this, so we wait for the schema tracker + require.NoError(t, + utils.WaitForAuthoritative(t, keyspaceName, "user", clusterInstance.VtgateProcess.ReadVSchema)) + + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("insert into user(id, name) values(6,'pikachu')") + + mcmp.AssertMatchesNoOrder( + `SELECT u.id, m.id FROM + (SELECT id, name FROM user LIMIT 10) AS u JOIN + (SELECT id, user_id FROM music LIMIT 10) as m on u.id = m.user_id`, + `[[INT64(1) INT64(1)] [INT64(5) INT64(2)] [INT64(1) INT64(3)] [INT64(2) INT64(4)] [INT64(3) INT64(5)] [INT64(5) INT64(7)] [INT64(4) INT64(6)]]`) + + mcmp.AssertMatchesNoOrder( + `SELECT u.id, m.id FROM user AS u LEFT JOIN + (SELECT id, user_id FROM music LIMIT 10) as m on u.id = m.user_id`, + `[[INT64(1) INT64(1)] [INT64(5) INT64(2)] [INT64(1) INT64(3)] [INT64(2) INT64(4)] [INT64(3) INT64(5)] [INT64(5) INT64(7)] [INT64(4) INT64(6)] [INT64(6) NULL]]`) +} diff --git a/go/test/endtoend/vtgate/queries/derived/schema.sql b/go/test/endtoend/vtgate/queries/derived/schema.sql index cf608028ed5..3cb8619d93b 100644 --- a/go/test/endtoend/vtgate/queries/derived/schema.sql +++ b/go/test/endtoend/vtgate/queries/derived/schema.sql @@ -1,13 +1,13 @@ create table user ( - id bigint, + id bigint, name varchar(255), primary key (id) ) Engine = InnoDB; create table music ( - id bigint, + id bigint, user_id bigint, primary key (id) ) Engine = InnoDB; diff --git a/go/test/endtoend/vtgate/queries/orderby/orderby_test.go b/go/test/endtoend/vtgate/queries/orderby/orderby_test.go index dd48a09fec7..43f800ee24c 100644 --- a/go/test/endtoend/vtgate/queries/orderby/orderby_test.go +++ b/go/test/endtoend/vtgate/queries/orderby/orderby_test.go @@ -68,7 +68,7 @@ func TestOrderBy(t *testing.T) { mcmp.AssertMatches("select id1, id2 from t4 order by id1 desc", `[[INT64(8) VARCHAR("F")] [INT64(7) VARCHAR("e")] [INT64(6) VARCHAR("d")] [INT64(5) VARCHAR("test")] [INT64(4) VARCHAR("c")] [INT64(3) VARCHAR("b")] [INT64(2) VARCHAR("Abc")] [INT64(1) VARCHAR("a")]]`) // test ordering of complex column if utils.BinaryIsAtLeastAtVersion(17, "vtgate") { - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ id1, id2 from t4 order by reverse(id2) desc", `[[INT64(5) VARCHAR("test")] [INT64(8) VARCHAR("F")] [INT64(7) VARCHAR("e")] [INT64(6) VARCHAR("d")] [INT64(2) VARCHAR("Abc")] [INT64(4) VARCHAR("c")] [INT64(3) VARCHAR("b")] [INT64(1) VARCHAR("a")]]`) + mcmp.AssertMatches("select id1, id2 from t4 order by reverse(id2) desc", `[[INT64(5) VARCHAR("test")] [INT64(8) VARCHAR("F")] [INT64(7) VARCHAR("e")] [INT64(6) VARCHAR("d")] [INT64(2) VARCHAR("Abc")] [INT64(4) VARCHAR("c")] [INT64(3) VARCHAR("b")] [INT64(1) VARCHAR("a")]]`) } defer func() { @@ -80,6 +80,6 @@ func TestOrderBy(t *testing.T) { mcmp.AssertMatches("select id1, id2 from t4 order by id2 desc", `[[INT64(5) VARCHAR("test")] [INT64(8) VARCHAR("F")] [INT64(7) VARCHAR("e")] [INT64(6) VARCHAR("d")] [INT64(4) VARCHAR("c")] [INT64(3) VARCHAR("b")] [INT64(2) VARCHAR("Abc")] [INT64(1) VARCHAR("a")]]`) mcmp.AssertMatches("select id1, id2 from t4 order by id1 desc", `[[INT64(8) VARCHAR("F")] [INT64(7) VARCHAR("e")] [INT64(6) VARCHAR("d")] [INT64(5) VARCHAR("test")] [INT64(4) VARCHAR("c")] [INT64(3) VARCHAR("b")] [INT64(2) VARCHAR("Abc")] [INT64(1) VARCHAR("a")]]`) if utils.BinaryIsAtLeastAtVersion(17, "vtgate") { - mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ id1, id2 from t4 order by reverse(id2) desc", `[[INT64(5) VARCHAR("test")] [INT64(8) VARCHAR("F")] [INT64(7) VARCHAR("e")] [INT64(6) VARCHAR("d")] [INT64(2) VARCHAR("Abc")] [INT64(4) VARCHAR("c")] [INT64(3) VARCHAR("b")] [INT64(1) VARCHAR("a")]]`) + mcmp.AssertMatches("select id1, id2 from t4 order by reverse(id2) desc", `[[INT64(5) VARCHAR("test")] [INT64(8) VARCHAR("F")] [INT64(7) VARCHAR("e")] [INT64(6) VARCHAR("d")] [INT64(2) VARCHAR("Abc")] [INT64(4) VARCHAR("c")] [INT64(3) VARCHAR("b")] [INT64(1) VARCHAR("a")]]`) } } diff --git a/go/test/endtoend/vtgate/queries/random/simplifier_test.go b/go/test/endtoend/vtgate/queries/random/simplifier_test.go index 478ee355d34..2be9ef8ab93 100644 --- a/go/test/endtoend/vtgate/queries/random/simplifier_test.go +++ b/go/test/endtoend/vtgate/queries/random/simplifier_test.go @@ -36,22 +36,22 @@ func TestSimplifyResultsMismatchedQuery(t *testing.T) { t.Skip("Skip CI") var queries []string - queries = append(queries, "select /*vt+ PLANNER=Gen4 */ (68 - -16) / case false when -45 then 3 when 28 then -43 else -62 end as crandom0 from dept as tbl0, (select /*vt+ PLANNER=Gen4 */ distinct not not false and count(*) from emp as tbl0, emp as tbl1 where tbl1.ename) as tbl1 limit 1", - "select /*vt+ PLANNER=Gen4 */ distinct case true when 'burro' then 'trout' else 'elf' end < case count(distinct true) when 'bobcat' then 'turkey' else 'penguin' end from dept as tbl0, emp as tbl1 where 'spider'", - "select /*vt+ PLANNER=Gen4 */ distinct sum(distinct tbl1.deptno) from dept as tbl0, emp as tbl1 where tbl0.deptno and tbl1.comm in (12, tbl0.deptno, case false when 67 then -17 when -78 then -35 end, -76 >> -68)", - "select /*vt+ PLANNER=Gen4 */ count(*) + 1 from emp as tbl0 order by count(*) desc", - "select /*vt+ PLANNER=Gen4 */ count(2 >> tbl2.mgr), sum(distinct tbl2.empno <=> 15) from emp as tbl0 left join emp as tbl2 on -32", - "select /*vt+ PLANNER=Gen4 */ sum(case false when true then tbl1.deptno else -154 / 132 end) as caggr1 from emp as tbl0, dept as tbl1", - "select /*vt+ PLANNER=Gen4 */ tbl1.dname as cgroup0, tbl1.dname as cgroup1 from dept as tbl0, dept as tbl1 group by tbl1.dname, tbl1.deptno order by tbl1.deptno desc", - "select /*vt+ PLANNER=Gen4 */ tbl0.ename as cgroup1 from emp as tbl0 group by tbl0.job, tbl0.ename having sum(tbl0.mgr) = sum(tbl0.mgr) order by tbl0.job desc, tbl0.ename asc limit 8", - "select /*vt+ PLANNER=Gen4 */ distinct count(*) as caggr1 from dept as tbl0, emp as tbl1 group by tbl1.sal having max(tbl1.comm) != true", - "select /*vt+ PLANNER=Gen4 */ distinct sum(tbl1.loc) as caggr0 from dept as tbl0, dept as tbl1 group by tbl1.deptno having max(tbl1.dname) <= 1", - "select /*vt+ PLANNER=Gen4 */ min(tbl0.deptno) as caggr0 from dept as tbl0, emp as tbl1 where case when false then tbl0.dname end group by tbl1.comm", - "select /*vt+ PLANNER=Gen4 */ count(*) as caggr0, 1 as crandom0 from dept as tbl0, emp as tbl1 where 1 = 0", - "select /*vt+ PLANNER=Gen4 */ count(*) as caggr0, 1 as crandom0 from dept as tbl0, emp as tbl1 where 'octopus'", - "select /*vt+ PLANNER=Gen4 */ distinct 'octopus' as crandom0 from dept as tbl0, emp as tbl1 where tbl0.deptno = tbl1.empno having count(*) = count(*)", - "select /*vt+ PLANNER=Gen4 */ max(tbl0.deptno) from dept as tbl0 right join emp as tbl1 on tbl0.deptno = tbl1.empno and tbl0.deptno = tbl1.deptno group by tbl0.deptno", - "select /*vt+ PLANNER=Gen4 */ count(tbl1.comm) from emp as tbl1 right join emp as tbl2 on tbl1.mgr = tbl2.sal") + queries = append(queries, "select (68 - -16) / case false when -45 then 3 when 28 then -43 else -62 end as crandom0 from dept as tbl0, (select distinct not not false and count(*) from emp as tbl0, emp as tbl1 where tbl1.ename) as tbl1 limit 1", + "select distinct case true when 'burro' then 'trout' else 'elf' end < case count(distinct true) when 'bobcat' then 'turkey' else 'penguin' end from dept as tbl0, emp as tbl1 where 'spider'", + "select distinct sum(distinct tbl1.deptno) from dept as tbl0, emp as tbl1 where tbl0.deptno and tbl1.comm in (12, tbl0.deptno, case false when 67 then -17 when -78 then -35 end, -76 >> -68)", + "select count(*) + 1 from emp as tbl0 order by count(*) desc", + "select count(2 >> tbl2.mgr), sum(distinct tbl2.empno <=> 15) from emp as tbl0 left join emp as tbl2 on -32", + "select sum(case false when true then tbl1.deptno else -154 / 132 end) as caggr1 from emp as tbl0, dept as tbl1", + "select tbl1.dname as cgroup0, tbl1.dname as cgroup1 from dept as tbl0, dept as tbl1 group by tbl1.dname, tbl1.deptno order by tbl1.deptno desc", + "select tbl0.ename as cgroup1 from emp as tbl0 group by tbl0.job, tbl0.ename having sum(tbl0.mgr) = sum(tbl0.mgr) order by tbl0.job desc, tbl0.ename asc limit 8", + "select distinct count(*) as caggr1 from dept as tbl0, emp as tbl1 group by tbl1.sal having max(tbl1.comm) != true", + "select distinct sum(tbl1.loc) as caggr0 from dept as tbl0, dept as tbl1 group by tbl1.deptno having max(tbl1.dname) <= 1", + "select min(tbl0.deptno) as caggr0 from dept as tbl0, emp as tbl1 where case when false then tbl0.dname end group by tbl1.comm", + "select count(*) as caggr0, 1 as crandom0 from dept as tbl0, emp as tbl1 where 1 = 0", + "select count(*) as caggr0, 1 as crandom0 from dept as tbl0, emp as tbl1 where 'octopus'", + "select distinct 'octopus' as crandom0 from dept as tbl0, emp as tbl1 where tbl0.deptno = tbl1.empno having count(*) = count(*)", + "select max(tbl0.deptno) from dept as tbl0 right join emp as tbl1 on tbl0.deptno = tbl1.empno and tbl0.deptno = tbl1.deptno group by tbl0.deptno", + "select count(tbl1.comm) from emp as tbl1 right join emp as tbl2 on tbl1.mgr = tbl2.sal") for _, query := range queries { var simplified string diff --git a/go/test/endtoend/vtgate/queries/reference/reference_test.go b/go/test/endtoend/vtgate/queries/reference/reference_test.go index 75efc840880..9c2508a889f 100644 --- a/go/test/endtoend/vtgate/queries/reference/reference_test.go +++ b/go/test/endtoend/vtgate/queries/reference/reference_test.go @@ -114,7 +114,7 @@ func TestReferenceRouting(t *testing.T) { utils.AssertMatches( t, conn, - `SELECT /*vt+ PLANNER=gen4 */ COUNT(zd.id) + `SELECT COUNT(zd.id) FROM delivery_failure df JOIN zip_detail zd ON zd.id = df.zip_detail_id WHERE zd.id = 3`, `[[INT64(0)]]`, diff --git a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go index 01cc7b2ee54..371de19dbe6 100644 --- a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go +++ b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go @@ -106,7 +106,7 @@ func TestSubqueryInUpdate(t *testing.T) { utils.Exec(t, conn, `insert into t1(id1, id2) values (1, 10), (2, 20), (3, 30), (4, 40), (5, 50)`) utils.Exec(t, conn, `insert into t2(id3, id4) values (1, 3), (2, 4)`) utils.AssertMatches(t, conn, `SELECT id2, keyspace_id FROM t1_id2_idx WHERE id2 IN (2,10)`, `[[INT64(10) VARBINARY("\x16k@\xb4J\xbaK\xd6")]]`) - utils.Exec(t, conn, `update /*vt+ PLANNER=gen4 */ t1 set id2 = (select count(*) from t2) where id1 = 1`) + utils.Exec(t, conn, `update t1 set id2 = (select count(*) from t2) where id1 = 1`) utils.AssertMatches(t, conn, `SELECT id2 FROM t1 WHERE id1 = 1`, `[[INT64(2)]]`) utils.AssertMatches(t, conn, `SELECT id2, keyspace_id FROM t1_id2_idx WHERE id2 IN (2,10)`, `[[INT64(2) VARBINARY("\x16k@\xb4J\xbaK\xd6")]]`) } diff --git a/go/vt/vtgate/engine/fake_vcursor_test.go b/go/vt/vtgate/engine/fake_vcursor_test.go index 6c99af33313..9776572f1a5 100644 --- a/go/vt/vtgate/engine/fake_vcursor_test.go +++ b/go/vt/vtgate/engine/fake_vcursor_test.go @@ -21,12 +21,15 @@ import ( "context" "fmt" "reflect" + "slices" "sort" "strings" "sync" "testing" "time" + "github.com/google/go-cmp/cmp" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/test/utils" @@ -806,18 +809,39 @@ func (t *noopVCursor) GetLogs() ([]ExecuteEntry, error) { return nil, nil } -func expectResult(t *testing.T, msg string, result, want *sqltypes.Result) { +func expectResult(t *testing.T, result, want *sqltypes.Result) { t.Helper() fieldsResult := fmt.Sprintf("%v", result.Fields) fieldsWant := fmt.Sprintf("%v", want.Fields) if fieldsResult != fieldsWant { - t.Errorf("%s (mismatch in Fields):\n%s\nwant:\n%s", msg, fieldsResult, fieldsWant) + t.Errorf("mismatch in Fields\n%s\nwant:\n%s", fieldsResult, fieldsWant) } rowsResult := fmt.Sprintf("%v", result.Rows) rowsWant := fmt.Sprintf("%v", want.Rows) if rowsResult != rowsWant { - t.Errorf("%s (mismatch in Rows):\n%s\nwant:\n%s", msg, rowsResult, rowsWant) + t.Errorf("mismatch in Rows:\n%s\nwant:\n%s", rowsResult, rowsWant) + } +} + +func expectResultAnyOrder(t *testing.T, result, want *sqltypes.Result) { + t.Helper() + f := func(a, b sqltypes.Row) int { + for i := range a { + l := a[i].RawStr() + r := b[i].RawStr() + x := strings.Compare(l, r) + if x == 0 { + continue + } + return x + } + return 0 + } + slices.SortFunc(result.Rows, f) + slices.SortFunc(want.Rows, f) + if diff := cmp.Diff(want, result); diff != "" { + t.Errorf("result: %+v, want %+v\ndiff: %s", result, want, diff) } } diff --git a/go/vt/vtgate/engine/hash_join.go b/go/vt/vtgate/engine/hash_join.go index a38fc21bf97..4f205f1bcdc 100644 --- a/go/vt/vtgate/engine/hash_join.go +++ b/go/vt/vtgate/engine/hash_join.go @@ -20,48 +20,69 @@ import ( "context" "fmt" "strings" + "sync" + "sync/atomic" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/vthash" ) var _ Primitive = (*HashJoin)(nil) -// HashJoin specifies the parameters for a join primitive -// Hash joins work by fetch all the input from the LHS, and building a hash map, known as the probe table, for this input. -// The key to the map is the hashcode of the value for column that we are joining by. -// Then the RHS is fetched, and we can check if the rows from the RHS matches any from the LHS. -// When they match by hash code, we double-check that we are not working with a false positive by comparing the values. -type HashJoin struct { - Opcode JoinOpcode - - // Left and Right are the LHS and RHS primitives - // of the Join. They can be any primitive. - Left, Right Primitive `json:",omitempty"` - - // Cols defines which columns from the left - // or right results should be used to build the - // return result. For results coming from the - // left query, the index values go as -1, -2, etc. - // For the right query, they're 1, 2, etc. - // If Cols is {-1, -2, 1, 2}, it means that - // the returned result will be {Left0, Left1, Right0, Right1}. - Cols []int `json:",omitempty"` - - // The keys correspond to the column offset in the inputs where - // the join columns can be found - LHSKey, RHSKey int - - // The join condition. Used for plan descriptions - ASTPred sqlparser.Expr - - // collation and type are used to hash the incoming values correctly - Collation collations.ID - ComparisonType querypb.Type -} +type ( + // HashJoin specifies the parameters for a join primitive + // Hash joins work by fetch all the input from the LHS, and building a hash map, known as the probe table, for this input. + // The key to the map is the hashcode of the value for column that we are joining by. + // Then the RHS is fetched, and we can check if the rows from the RHS matches any from the LHS. + // When they match by hash code, we double-check that we are not working with a false positive by comparing the values. + HashJoin struct { + Opcode JoinOpcode + + // Left and Right are the LHS and RHS primitives + // of the Join. They can be any primitive. + Left, Right Primitive `json:",omitempty"` + + // Cols defines which columns from the left + // or right results should be used to build the + // return result. For results coming from the + // left query, the index values go as -1, -2, etc. + // For the right query, they're 1, 2, etc. + // If Cols is {-1, -2, 1, 2}, it means that + // the returned result will be {Left0, Left1, Right0, Right1}. + Cols []int `json:",omitempty"` + + // The keys correspond to the column offset in the inputs where + // the join columns can be found + LHSKey, RHSKey int + + // The join condition. Used for plan descriptions + ASTPred sqlparser.Expr + + // collation and type are used to hash the incoming values correctly + Collation collations.ID + ComparisonType querypb.Type + } + + hashJoinProbeTable struct { + innerMap map[vthash.Hash]*probeTableEntry + + coll collations.ID + typ querypb.Type + lhsKey, rhsKey int + cols []int + hasher vthash.Hasher + } + + probeTableEntry struct { + row sqltypes.Row + next *probeTableEntry + seen bool + } +) // TryExecute implements the Primitive interface func (hj *HashJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { @@ -70,10 +91,13 @@ func (hj *HashJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma return nil, err } + pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols) // build the probe table from the LHS result - probeTable, err := hj.buildProbeTable(lresult) - if err != nil { - return nil, err + for _, row := range lresult.Rows { + err := pt.addLeftRow(row) + if err != nil { + return nil, err + } } rresult, err := vcursor.ExecutePrimitive(ctx, hj.Right, bindVars, wantfields) @@ -86,68 +110,37 @@ func (hj *HashJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma } for _, currentRHSRow := range rresult.Rows { - joinVal := currentRHSRow[hj.RHSKey] - if joinVal.IsNull() { - continue - } - hashcode, err := evalengine.NullsafeHashcode(joinVal, hj.Collation, hj.ComparisonType) + matches, err := pt.get(currentRHSRow) if err != nil { return nil, err } - lftRows := probeTable[hashcode] - for _, currentLHSRow := range lftRows { - lhsVal := currentLHSRow[hj.LHSKey] - // hash codes can give false positives, so we need to check with a real comparison as well - cmp, err := evalengine.NullsafeCompare(joinVal, lhsVal, hj.Collation) - if err != nil { - return nil, err - } + result.Rows = append(result.Rows, matches...) + } - if cmp == 0 { - // we have a match! - result.Rows = append(result.Rows, joinRows(currentLHSRow, currentRHSRow, hj.Cols)) - } - } + if hj.Opcode == LeftJoin { + result.Rows = append(result.Rows, pt.notFetched()...) } return result, nil } -func (hj *HashJoin) buildProbeTable(lresult *sqltypes.Result) (map[evalengine.HashCode][]sqltypes.Row, error) { - probeTable := map[evalengine.HashCode][]sqltypes.Row{} - for _, current := range lresult.Rows { - joinVal := current[hj.LHSKey] - if joinVal.IsNull() { - continue - } - hashcode, err := evalengine.NullsafeHashcode(joinVal, hj.Collation, hj.ComparisonType) - if err != nil { - return nil, err - } - probeTable[hashcode] = append(probeTable[hashcode], current) - } - return probeTable, nil -} - // TryStreamExecute implements the Primitive interface func (hj *HashJoin) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { // build the probe table from the LHS result - probeTable := map[evalengine.HashCode][]sqltypes.Row{} + pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols) var lfields []*querypb.Field + var mu sync.Mutex err := vcursor.StreamExecutePrimitive(ctx, hj.Left, bindVars, wantfields, func(result *sqltypes.Result) error { + mu.Lock() + defer mu.Unlock() if len(lfields) == 0 && len(result.Fields) != 0 { lfields = result.Fields } for _, current := range result.Rows { - joinVal := current[hj.LHSKey] - if joinVal.IsNull() { - continue - } - hashcode, err := evalengine.NullsafeHashcode(joinVal, hj.Collation, hj.ComparisonType) + err := pt.addLeftRow(current) if err != nil { return err } - probeTable[hashcode] = append(probeTable[hashcode], current) } return nil }) @@ -155,43 +148,50 @@ func (hj *HashJoin) TryStreamExecute(ctx context.Context, vcursor VCursor, bindV return err } - return vcursor.StreamExecutePrimitive(ctx, hj.Right, bindVars, wantfields, func(result *sqltypes.Result) error { + var sendFields atomic.Bool + sendFields.Store(wantfields) + + err = vcursor.StreamExecutePrimitive(ctx, hj.Right, bindVars, sendFields.Load(), func(result *sqltypes.Result) error { + mu.Lock() + defer mu.Unlock() // compare the results coming from the RHS with the probe-table res := &sqltypes.Result{} - if len(result.Fields) != 0 { - res = &sqltypes.Result{ - Fields: joinFields(lfields, result.Fields, hj.Cols), - } + if len(result.Fields) != 0 && sendFields.CompareAndSwap(true, false) { + res.Fields = joinFields(lfields, result.Fields, hj.Cols) } for _, currentRHSRow := range result.Rows { - joinVal := currentRHSRow[hj.RHSKey] - if joinVal.IsNull() { - continue - } - hashcode, err := evalengine.NullsafeHashcode(joinVal, hj.Collation, hj.ComparisonType) + results, err := pt.get(currentRHSRow) if err != nil { return err } - lftRows := probeTable[hashcode] - for _, currentLHSRow := range lftRows { - lhsVal := currentLHSRow[hj.LHSKey] - // hash codes can give false positives, so we need to check with a real comparison as well - cmp, err := evalengine.NullsafeCompare(joinVal, lhsVal, hj.Collation) - if err != nil { - return err - } - - if cmp == 0 { - // we have a match! - res.Rows = append(res.Rows, joinRows(currentLHSRow, currentRHSRow, hj.Cols)) - } - } + res.Rows = append(res.Rows, results...) } if len(res.Rows) != 0 || len(res.Fields) != 0 { return callback(res) } return nil }) + if err != nil { + return err + } + + if hj.Opcode == LeftJoin { + res := &sqltypes.Result{} + if sendFields.CompareAndSwap(true, false) { + // If we still have not sent the fields, we need to fetch + // the fields from the RHS to be able to build the result fields + rres, err := hj.Right.GetFields(ctx, vcursor, bindVars) + if err != nil { + return err + } + res.Fields = joinFields(lfields, rres.Fields, hj.Cols) + } + // this will only be called when all the concurrent access to the pt has + // ceased, so we don't need to lock it here + res.Rows = pt.notFetched() + return callback(res) + } + return nil } // RouteType implements the Primitive interface @@ -256,3 +256,69 @@ func (hj *HashJoin) description() PrimitiveDescription { Other: other, } } + +func newHashJoinProbeTable(coll collations.ID, typ querypb.Type, lhsKey, rhsKey int, cols []int) *hashJoinProbeTable { + return &hashJoinProbeTable{ + innerMap: map[vthash.Hash]*probeTableEntry{}, + coll: coll, + typ: typ, + lhsKey: lhsKey, + rhsKey: rhsKey, + cols: cols, + hasher: vthash.New(), + } +} + +func (pt *hashJoinProbeTable) addLeftRow(r sqltypes.Row) error { + hash, err := pt.hash(r[pt.lhsKey]) + if err != nil { + return err + } + pt.innerMap[hash] = &probeTableEntry{ + row: r, + next: pt.innerMap[hash], + } + + return nil +} + +func (pt *hashJoinProbeTable) hash(val sqltypes.Value) (vthash.Hash, error) { + err := evalengine.NullsafeHashcode128(&pt.hasher, val, pt.coll, pt.typ) + if err != nil { + return vthash.Hash{}, err + } + + res := pt.hasher.Sum128() + pt.hasher.Reset() + return res, nil +} + +func (pt *hashJoinProbeTable) get(rrow sqltypes.Row) (result []sqltypes.Row, err error) { + val := rrow[pt.rhsKey] + if val.IsNull() { + return + } + + hash, err := pt.hash(val) + if err != nil { + return nil, err + } + + for e := pt.innerMap[hash]; e != nil; e = e.next { + e.seen = true + result = append(result, joinRows(e.row, rrow, pt.cols)) + } + + return +} + +func (pt *hashJoinProbeTable) notFetched() (rows []sqltypes.Row) { + for _, e := range pt.innerMap { + for ; e != nil; e = e.next { + if !e.seen { + rows = append(rows, joinRows(e.row, nil, pt.cols)) + } + } + } + return +} diff --git a/go/vt/vtgate/engine/hash_join_test.go b/go/vt/vtgate/engine/hash_join_test.go index 8add0b78fa2..c0a3075c6ad 100644 --- a/go/vt/vtgate/engine/hash_join_test.go +++ b/go/vt/vtgate/engine/hash_join_test.go @@ -22,126 +22,152 @@ import ( "github.com/stretchr/testify/require" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/vtgate/evalengine" ) -func TestHashJoinExecuteSameType(t *testing.T) { - leftPrim := &fakePrimitive{ - results: []*sqltypes.Result{ - sqltypes.MakeTestResult( - sqltypes.MakeTestFields( - "col1|col2|col3", - "int64|varchar|varchar", +func TestHashJoinVariations(t *testing.T) { + // This test tries the different variations of hash-joins: + // comparing values of same type and different types, and both left and right outer joins + lhs := func() Primitive { + return &fakePrimitive{ + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2", + "int64|varchar", + ), + "1|1", + "2|2", + "3|b", + "null|b", ), - "1|a|aa", - "2|b|bb", - "3|c|cc", - ), - }, + }, + } } - rightPrim := &fakePrimitive{ - results: []*sqltypes.Result{ - sqltypes.MakeTestResult( - sqltypes.MakeTestFields( - "col4|col5|col6", - "int64|varchar|varchar", + rhs := func() Primitive { + return &fakePrimitive{ + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col4|col5", + "int64|varchar", + ), + "1|1", + "3|2", + "5|null", + "4|b", ), - "1|d|dd", - "3|e|ee", - "4|f|ff", - "3|g|gg", - ), - }, + }, + } } - // Normal join - jn := &HashJoin{ - Opcode: InnerJoin, - Left: leftPrim, - Right: rightPrim, - Cols: []int{-1, -2, 1, 2}, - LHSKey: 0, - RHSKey: 0, - } - r, err := jn.TryExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{}, true) - require.NoError(t, err) - leftPrim.ExpectLog(t, []string{ - `Execute true`, - }) - rightPrim.ExpectLog(t, []string{ - `Execute true`, - }) - expectResult(t, "jn.Execute", r, sqltypes.MakeTestResult( - sqltypes.MakeTestFields( - "col1|col2|col4|col5", - "int64|varchar|int64|varchar", - ), - "1|a|1|d", - "3|c|3|e", - "3|c|3|g", - )) -} + rows := func(r ...string) []string { return r } -func TestHashJoinExecuteDifferentType(t *testing.T) { - leftPrim := &fakePrimitive{ - results: []*sqltypes.Result{ - sqltypes.MakeTestResult( - sqltypes.MakeTestFields( - "col1|col2|col3", - "int64|varchar|varchar", - ), - "1|a|aa", - "2|b|bb", - "3|c|cc", - "5|c|cc", - ), - }, - } - rightPrim := &fakePrimitive{ - results: []*sqltypes.Result{ - sqltypes.MakeTestResult( - sqltypes.MakeTestFields( - "col4|col5|col6", - "varchar|varchar|varchar", - ), - "1.00|d|dd", - "3|e|ee", - "2.89|z|zz", - "4|f|ff", - "3|g|gg", - " 5.0toto|g|gg", - "w|ww|www", - ), - }, + tests := []struct { + name string + typ JoinOpcode + lhs, rhs int + expected []string + reverse bool + }{{ + name: "inner join, same type", + typ: InnerJoin, + lhs: 0, + rhs: 0, + expected: rows("1|1|1|1", "3|b|3|2"), + }, { + name: "inner join, coercion", + typ: InnerJoin, + lhs: 0, + rhs: 1, + expected: rows("1|1|1|1", "2|2|3|2"), + }, { + name: "left join, same type", + typ: LeftJoin, + lhs: 0, + rhs: 0, + expected: rows("1|1|1|1", "3|b|3|2", "2|2|null|null", "null|b|null|null"), + }, { + name: "left join, coercion", + typ: LeftJoin, + lhs: 0, + rhs: 1, + expected: rows("1|1|1|1", "2|2|3|2", "3|b|null|null", "null|b|null|null"), + }, { + name: "right join, same type", + typ: LeftJoin, + lhs: 0, + rhs: 0, + expected: rows("1|1|1|1", "3|2|3|b", "4|b|null|null", "5|null|null|null"), + reverse: true, + }, { + name: "right join, coercion", + typ: LeftJoin, + lhs: 0, + rhs: 1, + reverse: true, + expected: rows("1|1|1|1", "3|2|null|null", "4|b|null|null", "5|null|null|null"), + }} + + for _, tc := range tests { + + var fields []*querypb.Field + var first, last func() Primitive + if tc.reverse { + first, last = rhs, lhs + fields = sqltypes.MakeTestFields( + "col4|col5|col1|col2", + "int64|varchar|int64|varchar", + ) + } else { + first, last = lhs, rhs + fields = sqltypes.MakeTestFields( + "col1|col2|col4|col5", + "int64|varchar|int64|varchar", + ) + } + + expected := sqltypes.MakeTestResult(fields, tc.expected...) + + typ, err := evalengine.CoerceTypes(typeForOffset(tc.lhs), typeForOffset(tc.rhs)) + require.NoError(t, err) + + jn := &HashJoin{ + Opcode: tc.typ, + Cols: []int{-1, -2, 1, 2}, + LHSKey: tc.lhs, + RHSKey: tc.rhs, + Collation: typ.Coll, + ComparisonType: typ.Type, + } + + t.Run(tc.name, func(t *testing.T) { + jn.Left = first() + jn.Right = last() + r, err := jn.TryExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{}, true) + require.NoError(t, err) + expectResultAnyOrder(t, r, expected) + }) + t.Run("Streaming "+tc.name, func(t *testing.T) { + jn.Left = first() + jn.Right = last() + r, err := wrapStreamExecute(jn, &noopVCursor{}, map[string]*querypb.BindVariable{}, true) + require.NoError(t, err) + expectResultAnyOrder(t, r, expected) + }) } +} - // Normal join - jn := &HashJoin{ - Opcode: InnerJoin, - Left: leftPrim, - Right: rightPrim, - Cols: []int{-1, -2, 1, 2}, - LHSKey: 0, - RHSKey: 0, - ComparisonType: querypb.Type_FLOAT64, +func typeForOffset(i int) evalengine.Type { + switch i { + case 0: + return evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID} + case 1: + return evalengine.Type{Type: sqltypes.VarChar, Coll: collations.Default()} + default: + panic(i) } - r, err := jn.TryExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{}, true) - require.NoError(t, err) - leftPrim.ExpectLog(t, []string{ - `Execute true`, - }) - rightPrim.ExpectLog(t, []string{ - `Execute true`, - }) - expectResult(t, "jn.Execute", r, sqltypes.MakeTestResult( - sqltypes.MakeTestFields( - "col1|col2|col4|col5", - "int64|varchar|varchar|varchar", - ), - "1|a|1.00|d", - "3|c|3|e", - "3|c|3|g", - "5|c| 5.0toto|g", - )) } diff --git a/go/vt/vtgate/engine/insert_test.go b/go/vt/vtgate/engine/insert_test.go index 014654f37d6..d08ef856275 100644 --- a/go/vt/vtgate/engine/insert_test.go +++ b/go/vt/vtgate/engine/insert_test.go @@ -55,7 +55,7 @@ func TestInsertUnsharded(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAllShards()`, `ExecuteMultiShard ks.0: dummy_insert {} true true`, }) - expectResult(t, "Execute", result, &sqltypes.Result{InsertID: 4}) + expectResult(t, result, &sqltypes.Result{InsertID: 4}) // Failure cases vc = &loggingVCursor{shardErr: errors.New("shard_error")} @@ -117,7 +117,7 @@ func TestInsertUnshardedGenerate(t *testing.T) { }) // The insert id returned by ExecuteMultiShard should be overwritten by processGenerateFromValues. - expectResult(t, "Execute", result, &sqltypes.Result{InsertID: 4}) + expectResult(t, result, &sqltypes.Result{InsertID: 4}) } func TestInsertUnshardedGenerate_Zeros(t *testing.T) { @@ -170,7 +170,7 @@ func TestInsertUnshardedGenerate_Zeros(t *testing.T) { }) // The insert id returned by ExecuteMultiShard should be overwritten by processGenerateFromValues. - expectResult(t, "Execute", result, &sqltypes.Result{InsertID: 4}) + expectResult(t, result, &sqltypes.Result{InsertID: 4}) } func TestInsertShardedSimple(t *testing.T) { @@ -462,7 +462,7 @@ func TestInsertShardedGenerate(t *testing.T) { }) // The insert id returned by ExecuteMultiShard should be overwritten by processGenerateFromValues. - expectResult(t, "Execute", result, &sqltypes.Result{InsertID: 2}) + expectResult(t, result, &sqltypes.Result{InsertID: 2}) } func TestInsertShardedOwned(t *testing.T) { @@ -1797,7 +1797,7 @@ func TestInsertSelectGenerate(t *testing.T) { }) // The insert id returned by ExecuteMultiShard should be overwritten by processGenerateFromValues. - expectResult(t, "Execute", result, &sqltypes.Result{InsertID: 2}) + expectResult(t, result, &sqltypes.Result{InsertID: 2}) } func TestStreamingInsertSelectGenerate(t *testing.T) { @@ -1893,7 +1893,7 @@ func TestStreamingInsertSelectGenerate(t *testing.T) { }) // The insert id returned by ExecuteMultiShard should be overwritten by processGenerateFromValues. - expectResult(t, "Execute", output, &sqltypes.Result{InsertID: 2}) + expectResult(t, output, &sqltypes.Result{InsertID: 2}) } func TestInsertSelectGenerateNotProvided(t *testing.T) { @@ -1981,7 +1981,7 @@ func TestInsertSelectGenerateNotProvided(t *testing.T) { }) // The insert id returned by ExecuteMultiShard should be overwritten by processGenerateFromValues. - expectResult(t, "Execute", result, &sqltypes.Result{InsertID: 10}) + expectResult(t, result, &sqltypes.Result{InsertID: 10}) } func TestStreamingInsertSelectGenerateNotProvided(t *testing.T) { @@ -2073,7 +2073,7 @@ func TestStreamingInsertSelectGenerateNotProvided(t *testing.T) { }) // The insert id returned by ExecuteMultiShard should be overwritten by processGenerateFromValues. - expectResult(t, "Execute", output, &sqltypes.Result{InsertID: 10}) + expectResult(t, output, &sqltypes.Result{InsertID: 10}) } func TestInsertSelectUnowned(t *testing.T) { diff --git a/go/vt/vtgate/engine/join.go b/go/vt/vtgate/engine/join.go index ef50389c989..45b0d182dd7 100644 --- a/go/vt/vtgate/engine/join.go +++ b/go/vt/vtgate/engine/join.go @@ -225,7 +225,7 @@ func joinFields(lfields, rfields []*querypb.Field, cols []int) []*querypb.Field return fields } -func joinRows(lrow, rrow []sqltypes.Value, cols []int) []sqltypes.Value { +func joinRows(lrow, rrow sqltypes.Row, cols []int) sqltypes.Row { row := make([]sqltypes.Value, len(cols)) for i, index := range cols { if index < 0 { diff --git a/go/vt/vtgate/engine/join_test.go b/go/vt/vtgate/engine/join_test.go index 2df507f9512..eef5810ce69 100644 --- a/go/vt/vtgate/engine/join_test.go +++ b/go/vt/vtgate/engine/join_test.go @@ -89,7 +89,7 @@ func TestJoinExecute(t *testing.T) { `Execute a: type:INT64 value:"10" bv: type:VARCHAR value:"b" false`, `Execute a: type:INT64 value:"10" bv: type:VARCHAR value:"c" false`, }) - expectResult(t, "jn.Execute", r, sqltypes.MakeTestResult( + expectResult(t, r, sqltypes.MakeTestResult( sqltypes.MakeTestFields( "col1|col2|col4|col5", "int64|varchar|int64|varchar", @@ -116,7 +116,7 @@ func TestJoinExecute(t *testing.T) { `Execute a: type:INT64 value:"10" bv: type:VARCHAR value:"b" false`, `Execute a: type:INT64 value:"10" bv: type:VARCHAR value:"c" false`, }) - expectResult(t, "jn.Execute", r, sqltypes.MakeTestResult( + expectResult(t, r, sqltypes.MakeTestResult( sqltypes.MakeTestFields( "col1|col2|col4|col5", "int64|varchar|int64|varchar", @@ -251,7 +251,7 @@ func TestJoinExecuteNoResult(t *testing.T) { "int64|varchar|int64|varchar", ), ) - expectResult(t, "jn.Execute", r, wantResult) + expectResult(t, r, wantResult) } func TestJoinExecuteErrors(t *testing.T) { @@ -389,7 +389,7 @@ func TestJoinStreamExecute(t *testing.T) { `StreamExecute bv: type:VARCHAR value:"b" false`, `StreamExecute bv: type:VARCHAR value:"c" false`, }) - expectResult(t, "jn.Execute", r, sqltypes.MakeTestResult( + expectResult(t, r, sqltypes.MakeTestResult( sqltypes.MakeTestFields( "col1|col2|col4|col5", "int64|varchar|int64|varchar", @@ -418,7 +418,7 @@ func TestJoinStreamExecute(t *testing.T) { `StreamExecute bv: type:VARCHAR value:"b" false`, `StreamExecute bv: type:VARCHAR value:"c" false`, }) - expectResult(t, "jn.Execute", r, sqltypes.MakeTestResult( + expectResult(t, r, sqltypes.MakeTestResult( sqltypes.MakeTestFields( "col1|col2|col4|col5", "int64|varchar|int64|varchar", @@ -475,7 +475,7 @@ func TestGetFields(t *testing.T) { `GetFields bv: `, `Execute bv: true`, }) - expectResult(t, "jn.Execute", r, sqltypes.MakeTestResult( + expectResult(t, r, sqltypes.MakeTestResult( sqltypes.MakeTestFields( "col1|col2|col4|col5", "int64|varchar|int64|varchar", diff --git a/go/vt/vtgate/engine/route_test.go b/go/vt/vtgate/engine/route_test.go index 274ac58c7d4..54c7b5c6fb4 100644 --- a/go/vt/vtgate/engine/route_test.go +++ b/go/vt/vtgate/engine/route_test.go @@ -74,7 +74,7 @@ func TestSelectUnsharded(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAllShards()`, `ExecuteMultiShard ks.0: dummy_select {} false false`, }) - expectResult(t, "sel.Execute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) vc.Rewind() result, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) @@ -83,7 +83,7 @@ func TestSelectUnsharded(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAllShards()`, `StreamExecuteMulti dummy_select ks.0: {} `, }) - expectResult(t, "sel.StreamExecute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) } func TestInformationSchemaWithTableAndSchemaWithRoutedTables(t *testing.T) { @@ -219,7 +219,7 @@ func TestSelectScatter(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAllShards()`, `ExecuteMultiShard ks.-20: dummy_select {} ks.20-: dummy_select {} false false`, }) - expectResult(t, "sel.Execute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) vc.Rewind() result, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) @@ -228,7 +228,7 @@ func TestSelectScatter(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAllShards()`, `StreamExecuteMulti dummy_select ks.-20: {} ks.20-: {} `, }) - expectResult(t, "sel.StreamExecute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) } func TestSelectEqualUnique(t *testing.T) { @@ -257,7 +257,7 @@ func TestSelectEqualUnique(t *testing.T) { `ResolveDestinations ks [type:INT64 value:"1"] Destinations:DestinationKeyspaceID(166b40b44aba4bd6)`, `ExecuteMultiShard ks.-20: dummy_select {} false false`, }) - expectResult(t, "sel.Execute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) vc.Rewind() result, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) @@ -266,7 +266,7 @@ func TestSelectEqualUnique(t *testing.T) { `ResolveDestinations ks [type:INT64 value:"1"] Destinations:DestinationKeyspaceID(166b40b44aba4bd6)`, `StreamExecuteMulti dummy_select ks.-20: {} `, }) - expectResult(t, "sel.StreamExecute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) } func TestSelectNone(t *testing.T) { @@ -290,7 +290,7 @@ func TestSelectNone(t *testing.T) { result, err := sel.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) require.NoError(t, err) require.Empty(t, vc.log) - expectResult(t, "sel.Execute", result, &sqltypes.Result{}) + expectResult(t, result, &sqltypes.Result{}) vc.Rewind() result, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) @@ -308,7 +308,7 @@ func TestSelectNone(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAnyShard()`, `ExecuteMultiShard ks.-20: dummy_select {} false false`, }) - expectResult(t, "sel.Execute", result, &sqltypes.Result{}) + expectResult(t, result, &sqltypes.Result{}) vc.Rewind() result, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) @@ -317,7 +317,7 @@ func TestSelectNone(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAnyShard()`, `StreamExecuteMulti dummy_select ks.-20: {} `, }) - expectResult(t, "sel.StreamExecute", result, &sqltypes.Result{}) + expectResult(t, result, &sqltypes.Result{}) } func TestSelectEqualUniqueScatter(t *testing.T) { @@ -351,7 +351,7 @@ func TestSelectEqualUniqueScatter(t *testing.T) { `ResolveDestinations ks [type:INT64 value:"1"] Destinations:DestinationKeyRange(-)`, `ExecuteMultiShard ks.-20: dummy_select {} ks.20-: dummy_select {} false false`, }) - expectResult(t, "sel.Execute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) vc.Rewind() result, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) @@ -360,7 +360,7 @@ func TestSelectEqualUniqueScatter(t *testing.T) { `ResolveDestinations ks [type:INT64 value:"1"] Destinations:DestinationKeyRange(-)`, `StreamExecuteMulti dummy_select ks.-20: {} ks.20-: {} `, }) - expectResult(t, "sel.StreamExecute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) } func TestSelectEqual(t *testing.T) { @@ -403,7 +403,7 @@ func TestSelectEqual(t *testing.T) { `ResolveDestinations ks [type:INT64 value:"1"] Destinations:DestinationKeyspaceIDs(00,80)`, `ExecuteMultiShard ks.-20: dummy_select {} ks.20-: dummy_select {} false false`, }) - expectResult(t, "sel.Execute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) vc.Rewind() result, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) @@ -413,7 +413,7 @@ func TestSelectEqual(t *testing.T) { `ResolveDestinations ks [type:INT64 value:"1"] Destinations:DestinationKeyspaceIDs(00,80)`, `StreamExecuteMulti dummy_select ks.-20: {} ks.20-: {} `, }) - expectResult(t, "sel.StreamExecute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) } func TestSelectEqualNoRoute(t *testing.T) { @@ -443,7 +443,7 @@ func TestSelectEqualNoRoute(t *testing.T) { `Execute select from, toc from lkp where from in ::from from: type:TUPLE values:{type:INT64 value:"1"} false`, `ResolveDestinations ks [type:INT64 value:"1"] Destinations:DestinationNone()`, }) - expectResult(t, "sel.Execute", result, &sqltypes.Result{}) + expectResult(t, result, &sqltypes.Result{}) vc.Rewind() result, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) @@ -466,7 +466,7 @@ func TestSelectEqualNoRoute(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAnyShard()`, `ExecuteMultiShard ks.-20: dummy_select {} false false`, }) - expectResult(t, "sel.Execute", result, &sqltypes.Result{}) + expectResult(t, result, &sqltypes.Result{}) vc.Rewind() result, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) @@ -477,7 +477,7 @@ func TestSelectEqualNoRoute(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAnyShard()`, `StreamExecuteMulti dummy_select ks.-20: {} `, }) - expectResult(t, "sel.StreamExecute", result, &sqltypes.Result{}) + expectResult(t, result, &sqltypes.Result{}) } func TestINUnique(t *testing.T) { @@ -513,7 +513,7 @@ func TestINUnique(t *testing.T) { `ks.20-: dummy_select {__vals: type:TUPLE values:{type:INT64 value:"4"}} ` + `false false`, }) - expectResult(t, "sel.Execute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) vc.Rewind() result, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) @@ -522,7 +522,7 @@ func TestINUnique(t *testing.T) { `ResolveDestinations ks [type:INT64 value:"1" type:INT64 value:"2" type:INT64 value:"4"] Destinations:DestinationKeyspaceID(166b40b44aba4bd6),DestinationKeyspaceID(06e7ea22ce92708f),DestinationKeyspaceID(d2fd8867d50d2dfe)`, `StreamExecuteMulti dummy_select ks.-20: {__vals: type:TUPLE values:{type:INT64 value:"1"} values:{type:INT64 value:"2"}} ks.20-: {__vals: type:TUPLE values:{type:INT64 value:"4"}} `, }) - expectResult(t, "sel.StreamExecute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) } func TestINNonUnique(t *testing.T) { @@ -579,7 +579,7 @@ func TestINNonUnique(t *testing.T) { `ks.20-: dummy_select {__vals: type:TUPLE values:{type:INT64 value:"1"} values:{type:INT64 value:"4"}} ` + `false false`, }) - expectResult(t, "sel.Execute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) vc.Rewind() result, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) @@ -589,7 +589,7 @@ func TestINNonUnique(t *testing.T) { `ResolveDestinations ks [type:INT64 value:"1" type:INT64 value:"2" type:INT64 value:"4"] Destinations:DestinationKeyspaceIDs(00,80),DestinationKeyspaceIDs(00),DestinationKeyspaceIDs(80)`, `StreamExecuteMulti dummy_select ks.-20: {__vals: type:TUPLE values:{type:INT64 value:"1"} values:{type:INT64 value:"2"}} ks.20-: {__vals: type:TUPLE values:{type:INT64 value:"1"} values:{type:INT64 value:"4"}} `, }) - expectResult(t, "sel.StreamExecute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) } func TestMultiEqual(t *testing.T) { @@ -623,7 +623,7 @@ func TestMultiEqual(t *testing.T) { `ResolveDestinations ks [type:INT64 value:"1" type:INT64 value:"2" type:INT64 value:"4"] Destinations:DestinationKeyspaceID(166b40b44aba4bd6),DestinationKeyspaceID(06e7ea22ce92708f),DestinationKeyspaceID(d2fd8867d50d2dfe)`, `ExecuteMultiShard ks.-20: dummy_select {} ks.20-: dummy_select {} false false`, }) - expectResult(t, "sel.Execute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) vc.Rewind() result, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) @@ -632,7 +632,7 @@ func TestMultiEqual(t *testing.T) { `ResolveDestinations ks [type:INT64 value:"1" type:INT64 value:"2" type:INT64 value:"4"] Destinations:DestinationKeyspaceID(166b40b44aba4bd6),DestinationKeyspaceID(06e7ea22ce92708f),DestinationKeyspaceID(d2fd8867d50d2dfe)`, `StreamExecuteMulti dummy_select ks.-20: {} ks.20-: {} `, }) - expectResult(t, "sel.StreamExecute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) } func TestSelectLike(t *testing.T) { @@ -670,7 +670,7 @@ func TestSelectLike(t *testing.T) { `ResolveDestinations ks [type:VARCHAR value:"a%"] Destinations:DestinationKeyRange(0c-0d)`, `ExecuteMultiShard ks.-0c80: dummy_select {} ks.0c80-0d: dummy_select {} false false`, }) - expectResult(t, "sel.Execute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) vc.Rewind() @@ -681,7 +681,7 @@ func TestSelectLike(t *testing.T) { `ResolveDestinations ks [type:VARCHAR value:"a%"] Destinations:DestinationKeyRange(0c-0d)`, `StreamExecuteMulti dummy_select ks.-0c80: {} ks.0c80-0d: {} `, }) - expectResult(t, "sel.StreamExecute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) vc.Rewind() @@ -700,7 +700,7 @@ func TestSelectLike(t *testing.T) { `ResolveDestinations ks [type:VARCHAR value:"ab%"] Destinations:DestinationKeyRange(0c92-0c93)`, `ExecuteMultiShard ks.0c80-0d: dummy_select {} false false`, }) - expectResult(t, "sel.Execute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) vc.Rewind() @@ -711,7 +711,7 @@ func TestSelectLike(t *testing.T) { `ResolveDestinations ks [type:VARCHAR value:"ab%"] Destinations:DestinationKeyRange(0c92-0c93)`, `StreamExecuteMulti dummy_select ks.0c80-0d: {} `, }) - expectResult(t, "sel.StreamExecute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) } @@ -736,7 +736,7 @@ func TestSelectNext(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAllShards()`, `ExecuteMultiShard ks.-: dummy_select {} false false`, }) - expectResult(t, "sel.Execute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) vc.Rewind() result, _ = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) @@ -744,7 +744,7 @@ func TestSelectNext(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAllShards()`, `StreamExecuteMulti dummy_select ks.-: {} `, }) - expectResult(t, "sel.StreamExecute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) } func TestSelectDBA(t *testing.T) { @@ -768,7 +768,7 @@ func TestSelectDBA(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAnyShard()`, `ExecuteMultiShard ks.-20: dummy_select {} false false`, }) - expectResult(t, "sel.Execute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) vc.Rewind() result, _ = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) @@ -776,7 +776,7 @@ func TestSelectDBA(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAnyShard()`, `StreamExecuteMulti dummy_select ks.-20: {} `, }) - expectResult(t, "sel.StreamExecute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) } func TestSelectReference(t *testing.T) { @@ -800,7 +800,7 @@ func TestSelectReference(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAnyShard()`, `ExecuteMultiShard ks.-20: dummy_select {} false false`, }) - expectResult(t, "sel.Execute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) vc.Rewind() result, _ = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) @@ -808,7 +808,7 @@ func TestSelectReference(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAnyShard()`, `StreamExecuteMulti dummy_select ks.-20: {} `, }) - expectResult(t, "sel.StreamExecute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) } func TestRouteGetFields(t *testing.T) { @@ -840,7 +840,7 @@ func TestRouteGetFields(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAnyShard()`, `ExecuteMultiShard ks.-20: dummy_select_field {} false false`, }) - expectResult(t, "sel.Execute", result, &sqltypes.Result{}) + expectResult(t, result, &sqltypes.Result{}) vc.Rewind() result, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, true) @@ -851,7 +851,7 @@ func TestRouteGetFields(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAnyShard()`, `ExecuteMultiShard ks.-20: dummy_select_field {} false false`, }) - expectResult(t, "sel.StreamExecute", result, &sqltypes.Result{}) + expectResult(t, result, &sqltypes.Result{}) vc.Rewind() // test with special no-routes handling @@ -864,7 +864,7 @@ func TestRouteGetFields(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAnyShard()`, `ExecuteMultiShard ks.-20: dummy_select {} false false`, }) - expectResult(t, "sel.Execute", result, &sqltypes.Result{}) + expectResult(t, result, &sqltypes.Result{}) vc.Rewind() result, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, true) @@ -875,7 +875,7 @@ func TestRouteGetFields(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAnyShard()`, `StreamExecuteMulti dummy_select ks.-20: {} `, }) - expectResult(t, "sel.StreamExecute", result, &sqltypes.Result{}) + expectResult(t, result, &sqltypes.Result{}) } func TestRouteSort(t *testing.T) { @@ -924,7 +924,7 @@ func TestRouteSort(t *testing.T) { "2", "3", ) - expectResult(t, "sel.Execute", result, wantResult) + expectResult(t, result, wantResult) sel.OrderBy[0].Desc = true vc.Rewind() @@ -940,7 +940,7 @@ func TestRouteSort(t *testing.T) { "1", "1", ) - expectResult(t, "sel.Execute", result, wantResult) + expectResult(t, result, wantResult) vc = &loggingVCursor{ shards: []string{"0"}, @@ -1013,7 +1013,7 @@ func TestRouteSortWeightStrings(t *testing.T) { "g|d", "v|x", ) - expectResult(t, "sel.Execute", result, wantResult) + expectResult(t, result, wantResult) }) t.Run("Descending ordering using weighted strings", func(t *testing.T) { @@ -1032,7 +1032,7 @@ func TestRouteSortWeightStrings(t *testing.T) { "c|t", "a|a", ) - expectResult(t, "sel.Execute", result, wantResult) + expectResult(t, result, wantResult) }) t.Run("Error when no weight string set", func(t *testing.T) { @@ -1118,7 +1118,7 @@ func TestRouteSortCollation(t *testing.T) { "cs", "d", ) - expectResult(t, "sel.Execute", result, wantResult) + expectResult(t, result, wantResult) }) t.Run("Descending ordering using Collation", func(t *testing.T) { @@ -1137,7 +1137,7 @@ func TestRouteSortCollation(t *testing.T) { "c", "c", ) - expectResult(t, "sel.Execute", result, wantResult) + expectResult(t, result, wantResult) }) t.Run("Error when Unknown Collation", func(t *testing.T) { @@ -1239,7 +1239,7 @@ func TestRouteSortTruncate(t *testing.T) { "2", "3", ) - expectResult(t, "sel.Execute", result, wantResult) + expectResult(t, result, wantResult) } func TestRouteStreamTruncate(t *testing.T) { @@ -1281,7 +1281,7 @@ func TestRouteStreamTruncate(t *testing.T) { "1", "2", ) - expectResult(t, "sel.Execute", result, wantResult) + expectResult(t, result, wantResult) } func TestRouteStreamSortTruncate(t *testing.T) { @@ -1330,7 +1330,7 @@ func TestRouteStreamSortTruncate(t *testing.T) { "1", "2", ) - expectResult(t, "sel.Execute", result, wantResult) + expectResult(t, result, wantResult) } func TestParamsFail(t *testing.T) { @@ -1432,7 +1432,7 @@ func TestExecFail(t *testing.T) { `ResolveDestinations ks [] Destinations:DestinationAllShards()`, `ExecuteMultiShard ks.-20: dummy_select {} ks.20-: dummy_select {} false false`, }) - expectResult(t, "sel.Execute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) vc.Rewind() vc.resultErr = sqlerror.NewSQLError(sqlerror.ERQueryInterrupted, "", "query timeout -20") @@ -1474,7 +1474,7 @@ func TestSelectEqualUniqueMultiColumnVindex(t *testing.T) { `ResolveDestinationsMultiCol ks [[INT64(1) INT64(2)]] Destinations:DestinationKeyspaceID(0106e7ea22ce92708f)`, `ExecuteMultiShard ks.-20: dummy_select {} false false`, }) - expectResult(t, "sel.Execute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) vc.Rewind() result, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) @@ -1483,7 +1483,7 @@ func TestSelectEqualUniqueMultiColumnVindex(t *testing.T) { `ResolveDestinationsMultiCol ks [[INT64(1) INT64(2)]] Destinations:DestinationKeyspaceID(0106e7ea22ce92708f)`, `StreamExecuteMulti dummy_select ks.-20: {} `, }) - expectResult(t, "sel.StreamExecute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) } func TestSelectEqualMultiColumnVindex(t *testing.T) { @@ -1511,7 +1511,7 @@ func TestSelectEqualMultiColumnVindex(t *testing.T) { `ResolveDestinationsMultiCol ks [[INT64(32)]] Destinations:DestinationKeyRange(20-21)`, `ExecuteMultiShard ks.-20: dummy_select {} ks.20-: dummy_select {} false false`, }) - expectResult(t, "sel.StreamExecute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) vc.Rewind() result, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) @@ -1520,7 +1520,7 @@ func TestSelectEqualMultiColumnVindex(t *testing.T) { `ResolveDestinationsMultiCol ks [[INT64(32)]] Destinations:DestinationKeyRange(20-21)`, `StreamExecuteMulti dummy_select ks.-20: {} ks.20-: {} `, }) - expectResult(t, "sel.StreamExecute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) } func TestINMultiColumnVindex(t *testing.T) { @@ -1557,7 +1557,7 @@ func TestINMultiColumnVindex(t *testing.T) { `ResolveDestinationsMultiCol ks [[INT64(1) INT64(3)] [INT64(1) INT64(4)] [INT64(2) INT64(3)] [INT64(2) INT64(4)]] Destinations:DestinationKeyspaceID(014eb190c9a2fa169c),DestinationKeyspaceID(01d2fd8867d50d2dfe),DestinationKeyspaceID(024eb190c9a2fa169c),DestinationKeyspaceID(02d2fd8867d50d2dfe)`, `ExecuteMultiShard ks.-20: dummy_select {__vals0: type:TUPLE values:{type:INT64 value:"1"} __vals1: type:TUPLE values:{type:INT64 value:"3"}} ks.20-: dummy_select {__vals0: type:TUPLE values:{type:INT64 value:"1"} values:{type:INT64 value:"2"} __vals1: type:TUPLE values:{type:INT64 value:"4"} values:{type:INT64 value:"3"}} false false`, }) - expectResult(t, "sel.Execute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) vc.Rewind() result, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) @@ -1566,7 +1566,7 @@ func TestINMultiColumnVindex(t *testing.T) { `ResolveDestinationsMultiCol ks [[INT64(1) INT64(3)] [INT64(1) INT64(4)] [INT64(2) INT64(3)] [INT64(2) INT64(4)]] Destinations:DestinationKeyspaceID(014eb190c9a2fa169c),DestinationKeyspaceID(01d2fd8867d50d2dfe),DestinationKeyspaceID(024eb190c9a2fa169c),DestinationKeyspaceID(02d2fd8867d50d2dfe)`, `StreamExecuteMulti dummy_select ks.-20: {__vals0: type:TUPLE values:{type:INT64 value:"1"} __vals1: type:TUPLE values:{type:INT64 value:"3"}} ks.20-: {__vals0: type:TUPLE values:{type:INT64 value:"1"} values:{type:INT64 value:"2"} __vals1: type:TUPLE values:{type:INT64 value:"4"} values:{type:INT64 value:"3"}} `, }) - expectResult(t, "sel.StreamExecute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) } func TestINMixedMultiColumnComparision(t *testing.T) { @@ -1600,7 +1600,7 @@ func TestINMixedMultiColumnComparision(t *testing.T) { `ResolveDestinationsMultiCol ks [[INT64(1) INT64(3)] [INT64(1) INT64(4)]] Destinations:DestinationKeyspaceID(014eb190c9a2fa169c),DestinationKeyspaceID(01d2fd8867d50d2dfe)`, `ExecuteMultiShard ks.-20: dummy_select {__vals1: type:TUPLE values:{type:INT64 value:"3"}} ks.20-: dummy_select {__vals1: type:TUPLE values:{type:INT64 value:"4"}} false false`, }) - expectResult(t, "sel.Execute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) vc.Rewind() result, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) @@ -1609,7 +1609,7 @@ func TestINMixedMultiColumnComparision(t *testing.T) { `ResolveDestinationsMultiCol ks [[INT64(1) INT64(3)] [INT64(1) INT64(4)]] Destinations:DestinationKeyspaceID(014eb190c9a2fa169c),DestinationKeyspaceID(01d2fd8867d50d2dfe)`, `StreamExecuteMulti dummy_select ks.-20: {__vals1: type:TUPLE values:{type:INT64 value:"3"}} ks.20-: {__vals1: type:TUPLE values:{type:INT64 value:"4"}} `, }) - expectResult(t, "sel.StreamExecute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) } func TestMultiEqualMultiCol(t *testing.T) { @@ -1643,7 +1643,7 @@ func TestMultiEqualMultiCol(t *testing.T) { `ResolveDestinationsMultiCol ks [[INT64(1) INT64(2)] [INT64(3) INT64(4)]] Destinations:DestinationKeyspaceID(0106e7ea22ce92708f),DestinationKeyspaceID(03d2fd8867d50d2dfe)`, `ExecuteMultiShard ks.-20: dummy_select {} ks.40-: dummy_select {} false false`, }) - expectResult(t, "sel.Execute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) vc.Rewind() result, err = wrapStreamExecute(sel, vc, map[string]*querypb.BindVariable{}, false) @@ -1652,7 +1652,7 @@ func TestMultiEqualMultiCol(t *testing.T) { `ResolveDestinationsMultiCol ks [[INT64(1) INT64(2)] [INT64(3) INT64(4)]] Destinations:DestinationKeyspaceID(0106e7ea22ce92708f),DestinationKeyspaceID(03d2fd8867d50d2dfe)`, `StreamExecuteMulti dummy_select ks.-20: {} ks.40-: {} `, }) - expectResult(t, "sel.StreamExecute", result, defaultSelectResult) + expectResult(t, result, defaultSelectResult) } func TestBuildRowColValues(t *testing.T) { diff --git a/go/vt/vtgate/engine/semi_join_test.go b/go/vt/vtgate/engine/semi_join_test.go index ca89882ab8a..9cf55d4f78f 100644 --- a/go/vt/vtgate/engine/semi_join_test.go +++ b/go/vt/vtgate/engine/semi_join_test.go @@ -152,7 +152,7 @@ func TestSemiJoinStreamExecute(t *testing.T) { `StreamExecute bv: type:VARCHAR value:"c" false`, `StreamExecute bv: type:VARCHAR value:"d" false`, }) - expectResult(t, "jn.Execute", r, sqltypes.MakeTestResult( + expectResult(t, r, sqltypes.MakeTestResult( sqltypes.MakeTestFields( "col1|col2|col3", "int64|varchar|varchar", diff --git a/go/vt/vtgate/engine/simple_projection_test.go b/go/vt/vtgate/engine/simple_projection_test.go index 99d644c93af..6fdc288095c 100644 --- a/go/vt/vtgate/engine/simple_projection_test.go +++ b/go/vt/vtgate/engine/simple_projection_test.go @@ -59,7 +59,7 @@ func TestSubqueryExecute(t *testing.T) { prim.ExpectLog(t, []string{ `Execute a: type:INT64 value:"1" true`, }) - expectResult(t, "sq.Execute", r, sqltypes.MakeTestResult( + expectResult(t, r, sqltypes.MakeTestResult( sqltypes.MakeTestFields( "col1|col3", "int64|varchar", @@ -108,7 +108,7 @@ func TestSubqueryStreamExecute(t *testing.T) { prim.ExpectLog(t, []string{ `StreamExecute a: type:INT64 value:"1" true`, }) - expectResult(t, "sq.Execute", r, sqltypes.MakeTestResult( + expectResult(t, r, sqltypes.MakeTestResult( sqltypes.MakeTestFields( "col1|col3", "int64|varchar", @@ -158,7 +158,7 @@ func TestSubqueryGetFields(t *testing.T) { `GetFields a: type:INT64 value:"1"`, `Execute a: type:INT64 value:"1" true`, }) - expectResult(t, "sq.Execute", r, sqltypes.MakeTestResult( + expectResult(t, r, sqltypes.MakeTestResult( sqltypes.MakeTestFields( "col1|col3", "int64|varchar", diff --git a/go/vt/vtgate/engine/uncorrelated_subquery_test.go b/go/vt/vtgate/engine/uncorrelated_subquery_test.go index 3e80c6369a7..085fe09238f 100644 --- a/go/vt/vtgate/engine/uncorrelated_subquery_test.go +++ b/go/vt/vtgate/engine/uncorrelated_subquery_test.go @@ -65,7 +65,7 @@ func TestPulloutSubqueryValueGood(t *testing.T) { require.NoError(t, err) sfp.ExpectLog(t, []string{`Execute aa: type:INT64 value:"1" false`}) ufp.ExpectLog(t, []string{`Execute aa: type:INT64 value:"1" sq: type:INT64 value:"1" false`}) - expectResult(t, "ps.Execute", result, underlyingResult) + expectResult(t, result, underlyingResult) } func TestPulloutSubqueryValueNone(t *testing.T) { @@ -279,7 +279,7 @@ func TestPulloutSubqueryStream(t *testing.T) { require.NoError(t, err) sfp.ExpectLog(t, []string{`Execute aa: type:INT64 value:"1" false`}) ufp.ExpectLog(t, []string{`StreamExecute aa: type:INT64 value:"1" sq: type:INT64 value:"1" true`}) - expectResult(t, "ps.StreamExecute", result, underlyingResult) + expectResult(t, result, underlyingResult) } func TestPulloutSubqueryGetFields(t *testing.T) { diff --git a/go/vt/vtgate/evalengine/api_coerce.go b/go/vt/vtgate/evalengine/api_coerce.go index 130727d8f31..5e92431e555 100644 --- a/go/vt/vtgate/evalengine/api_coerce.go +++ b/go/vt/vtgate/evalengine/api_coerce.go @@ -19,6 +19,8 @@ package evalengine import ( "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" ) func CoerceTo(value sqltypes.Value, typ sqltypes.Type) (sqltypes.Value, error) { @@ -28,3 +30,69 @@ func CoerceTo(value sqltypes.Value, typ sqltypes.Type) (sqltypes.Value, error) { } return evalToSQLValueWithType(cast, typ), nil } + +// CoerceTypes takes two input types, and decides how they should be coerced before compared +func CoerceTypes(v1, v2 Type) (out Type, err error) { + if v1 == v2 { + return v1, nil + } + if sqltypes.IsNull(v1.Type) || sqltypes.IsNull(v2.Type) { + return Type{Type: sqltypes.Null, Coll: collations.CollationBinaryID, Nullable: true}, nil + } + fail := func() error { + return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "types does not support hashcode yet: %v vs %v", v1.Type, v2.Type) + } + + out = Type{Nullable: v1.Nullable || v2.Nullable} + + switch { + case sqltypes.IsTextOrBinary(v1.Type) && sqltypes.IsTextOrBinary(v2.Type): + out.Type = sqltypes.VarChar + mergedCollation, _, _, ferr := mergeCollations(typedCoercionCollation(v1.Type, v1.Coll), typedCoercionCollation(v2.Type, v2.Coll), v1.Type, v2.Type) + if err != nil { + return Type{}, ferr + } + out.Coll = mergedCollation.Collation + return + + case sqltypes.IsDateOrTime(v1.Type): + out.Coll = collations.CollationBinaryID + out.Type = v1.Type + return + + case sqltypes.IsDateOrTime(v2.Type): + out.Coll = collations.CollationBinaryID + out.Type = v2.Type + return + + case sqltypes.IsNumber(v1.Type) || sqltypes.IsNumber(v2.Type): + out.Coll = collations.CollationBinaryID + switch { + case sqltypes.IsTextOrBinary(v1.Type) || sqltypes.IsFloat(v1.Type) || sqltypes.IsDecimal(v1.Type) || + sqltypes.IsTextOrBinary(v2.Type) || sqltypes.IsFloat(v2.Type) || sqltypes.IsDecimal(v2.Type): + out.Type = sqltypes.Float64 + return + case sqltypes.IsSigned(v1.Type): + switch { + case sqltypes.IsUnsigned(v2.Type): + out.Type = sqltypes.Uint64 + return + case sqltypes.IsSigned(v2.Type): + out.Type = sqltypes.Int64 + return + default: + return Type{}, fail() + } + case sqltypes.IsUnsigned(v1.Type): + switch { + case sqltypes.IsSigned(v2.Type) || sqltypes.IsUnsigned(v2.Type): + out.Type = sqltypes.Uint64 + return + default: + return Type{}, fail() + } + } + } + + return Type{}, fail() +} diff --git a/go/vt/vtgate/evalengine/api_hash_test.go b/go/vt/vtgate/evalengine/api_hash_test.go index 832a1ed3b88..0add16de89d 100644 --- a/go/vt/vtgate/evalengine/api_hash_test.go +++ b/go/vt/vtgate/evalengine/api_hash_test.go @@ -24,13 +24,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" - "vitess.io/vitess/go/vt/vthash" - - "vitess.io/vitess/go/mysql/collations" - "vitess.io/vitess/go/sqltypes" ) func TestHashCodes(t *testing.T) { @@ -197,7 +195,7 @@ func coerceTo(v1, v2 sqltypes.Type) (sqltypes.Type, error) { if sqltypes.IsNull(v1) || sqltypes.IsNull(v2) { return sqltypes.Null, nil } - if (sqltypes.IsText(v1) || sqltypes.IsBinary(v1)) && (sqltypes.IsText(v2) || sqltypes.IsBinary(v2)) { + if (sqltypes.IsTextOrBinary(v1)) && (sqltypes.IsTextOrBinary(v2)) { return sqltypes.VarChar, nil } if sqltypes.IsDateOrTime(v1) { @@ -209,7 +207,7 @@ func coerceTo(v1, v2 sqltypes.Type) (sqltypes.Type, error) { if sqltypes.IsNumber(v1) || sqltypes.IsNumber(v2) { switch { - case sqltypes.IsText(v1) || sqltypes.IsBinary(v1) || sqltypes.IsText(v2) || sqltypes.IsBinary(v2): + case sqltypes.IsTextOrBinary(v1) || sqltypes.IsTextOrBinary(v2): return sqltypes.Float64, nil case sqltypes.IsFloat(v2) || v2 == sqltypes.Decimal || sqltypes.IsFloat(v1) || v1 == sqltypes.Decimal: return sqltypes.Float64, nil diff --git a/go/vt/vtgate/evalengine/collation.go b/go/vt/vtgate/evalengine/collation.go index 7cb341f52b0..b4e589c9724 100644 --- a/go/vt/vtgate/evalengine/collation.go +++ b/go/vt/vtgate/evalengine/collation.go @@ -59,8 +59,8 @@ func mergeCollations(c1, c2 collations.TypedCollation, t1, t2 sqltypes.Type) (co return c1, nil, nil, nil } - lt := sqltypes.IsText(t1) || sqltypes.IsBinary(t1) - rt := sqltypes.IsText(t2) || sqltypes.IsBinary(t2) + lt := sqltypes.IsTextOrBinary(t1) + rt := sqltypes.IsTextOrBinary(t2) if !lt || !rt { if lt { return c1, nil, nil, nil diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 21a25ad3163..1c2feeb5f15 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -67,7 +67,7 @@ func (ct ctype) nullable() bool { } func (ct ctype) isTextual() bool { - return sqltypes.IsText(ct.Type) || sqltypes.IsBinary(ct.Type) + return sqltypes.IsTextOrBinary(ct.Type) } func (ct ctype) isHexOrBitLiteral() bool { diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index e327b9d5651..82f1ec688c7 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -312,7 +312,7 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I return newEvalUint64(uint64(i.i)), nil } - case sqltypes.IsText(typ) || sqltypes.IsBinary(typ): + case sqltypes.IsTextOrBinary(typ): switch { case v.IsText() || v.IsBinary(): return newEvalRaw(v.Type(), v.Raw(), typedCoercionCollation(v.Type(), collation)), nil diff --git a/go/vt/vtgate/evalengine/expr_compare.go b/go/vt/vtgate/evalengine/expr_compare.go index 6639bb7f7e2..90c2d313b07 100644 --- a/go/vt/vtgate/evalengine/expr_compare.go +++ b/go/vt/vtgate/evalengine/expr_compare.go @@ -114,7 +114,7 @@ func (compareNullSafeEQ) compare(left, right eval) (boolean, error) { } func typeIsTextual(tt sqltypes.Type) bool { - return sqltypes.IsText(tt) || sqltypes.IsBinary(tt) || tt == sqltypes.Time + return sqltypes.IsTextOrBinary(tt) || tt == sqltypes.Time } func compareAsStrings(l, r sqltypes.Type) bool { diff --git a/go/vt/vtgate/planbuilder/join.go b/go/vt/vtgate/planbuilder/join.go index 02027a8b49e..462b45fa00a 100644 --- a/go/vt/vtgate/planbuilder/join.go +++ b/go/vt/vtgate/planbuilder/join.go @@ -55,3 +55,16 @@ func (j *join) Primitive() engine.Primitive { Opcode: j.Opcode, } } + +type hashJoin struct { + lhs, rhs logicalPlan + inner *engine.HashJoin +} + +func (hj *hashJoin) Primitive() engine.Primitive { + lhs := hj.lhs.Primitive() + rhs := hj.rhs.Primitive() + hj.inner.Left = lhs + hj.inner.Right = rhs + return hj.inner +} diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 5f59a96000a..7d3223a48e9 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -68,6 +68,8 @@ func transformToLogicalPlan(ctx *plancontext.PlanningContext, op ops.Operator) ( return transformFkVerify(ctx, op) case *operators.InsertSelection: return transformInsertionSelection(ctx, op) + case *operators.HashJoin: + return transformHashJoin(ctx, op) case *operators.Sequential: return transformSequential(ctx, op) } @@ -810,3 +812,58 @@ func createLimit(input logicalPlan, limit *sqlparser.Limit) (logicalPlan, error) return plan, nil } + +func transformHashJoin(ctx *plancontext.PlanningContext, op *operators.HashJoin) (logicalPlan, error) { + lhs, err := transformToLogicalPlan(ctx, op.LHS) + if err != nil { + return nil, err + } + rhs, err := transformToLogicalPlan(ctx, op.RHS) + if err != nil { + return nil, err + } + + if len(op.LHSKeys) != 1 { + return nil, vterrors.VT12001("hash joins must have exactly one join predicate") + } + + joinOp := engine.InnerJoin + if op.LeftJoin { + joinOp = engine.LeftJoin + } + + var missingTypes []string + + ltyp, found := ctx.SemTable.TypeForExpr(op.JoinComparisons[0].LHS) + if !found { + missingTypes = append(missingTypes, sqlparser.String(op.JoinComparisons[0].LHS)) + } + rtyp, found := ctx.SemTable.TypeForExpr(op.JoinComparisons[0].RHS) + if !found { + missingTypes = append(missingTypes, sqlparser.String(op.JoinComparisons[0].RHS)) + } + + if len(missingTypes) > 0 { + return nil, vterrors.VT12001( + fmt.Sprintf("missing type information for [%s]", strings.Join(missingTypes, ", "))) + } + + comparisonType, err := evalengine.CoerceTypes(ltyp, rtyp) + if err != nil { + return nil, err + } + + return &hashJoin{ + lhs: lhs, + rhs: rhs, + inner: &engine.HashJoin{ + Opcode: joinOp, + Cols: op.ColumnOffsets, + LHSKey: op.LHSKeys[0], + RHSKey: op.RHSKeys[0], + ASTPred: op.JoinPredicate(), + Collation: comparisonType.Coll, + ComparisonType: comparisonType.Type, + }, + }, nil +} diff --git a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go index 0127e81b4c4..edba5c51256 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go +++ b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go @@ -116,12 +116,12 @@ func pushAggregationThroughSubquery( src.Outer = pushedAggr if !rootAggr.Original { - return src, rewrite.NewTree("push Aggregation under subquery - keep original", rootAggr), nil + return src, rewrite.NewTree("push Aggregation under subquery - keep original"), nil } rootAggr.aggregateTheAggregates() - return rootAggr, rewrite.NewTree("push Aggregation under subquery", rootAggr), nil + return rootAggr, rewrite.NewTree("push Aggregation under subquery"), nil } func (a *Aggregator) aggregateTheAggregates() { @@ -161,10 +161,10 @@ func pushAggregationThroughRoute( if !aggregator.Original { // we only keep the root aggregation, if this aggregator was created // by splitting one and pushing under a join, we can get rid of this one - return aggregator.Source, rewrite.NewTree("push aggregation under route - remove original", aggregator), nil + return aggregator.Source, rewrite.NewTree("push aggregation under route - remove original"), nil } - return aggregator, rewrite.NewTree("push aggregation under route - keep original", aggregator), nil + return aggregator, rewrite.NewTree("push aggregation under route - keep original"), nil } // pushAggregations splits aggregations between the original aggregator and the one we are pushing down @@ -264,10 +264,10 @@ withNextColumn: if !aggregator.Original { // we only keep the root aggregation, if this aggregator was created // by splitting one and pushing under a join, we can get rid of this one - return aggregator.Source, rewrite.NewTree("push aggregation under filter - remove original", aggregator), nil + return aggregator.Source, rewrite.NewTree("push aggregation under filter - remove original"), nil } aggregator.aggregateTheAggregates() - return aggregator, rewrite.NewTree("push aggregation under filter - keep original", aggregator), nil + return aggregator, rewrite.NewTree("push aggregation under filter - keep original"), nil } func collectColNamesNeeded(ctx *plancontext.PlanningContext, f *Filter) (columnsNeeded []*sqlparser.ColName) { @@ -411,12 +411,12 @@ func pushAggregationThroughJoin(ctx *plancontext.PlanningContext, rootAggr *Aggr if !rootAggr.Original { // we only keep the root aggregation, if this aggregator was created // by splitting one and pushing under a join, we can get rid of this one - return output, rewrite.NewTree("push Aggregation under join - keep original", rootAggr), nil + return output, rewrite.NewTree("push Aggregation under join - keep original"), nil } rootAggr.aggregateTheAggregates() rootAggr.Source = output - return rootAggr, rewrite.NewTree("push Aggregation under join", rootAggr), nil + return rootAggr, rewrite.NewTree("push Aggregation under join"), nil } var errAbortAggrPushing = fmt.Errorf("abort aggregation pushing") @@ -473,7 +473,7 @@ func splitGroupingToLeftAndRight(ctx *plancontext.PlanningContext, rootAggr *Agg RHSExpr: expr, }) case deps.IsSolvedBy(lhs.tableID.Merge(rhs.tableID)): - jc, err := BreakExpressionInLHSandRHS(ctx, groupBy.SimplifiedExpr, lhs.tableID) + jc, err := breakExpressionInLHSandRHSForApplyJoin(ctx, groupBy.SimplifiedExpr, lhs.tableID) if err != nil { return nil, err } @@ -877,5 +877,5 @@ func splitAvgAggregations(ctx *plancontext.PlanningContext, aggr *Aggregator) (o aggr.Columns = append(aggr.Columns, columns...) aggr.Aggregations = append(aggr.Aggregations, aggregations...) - return proj, rewrite.NewTree("split avg aggregation", proj), nil + return proj, rewrite.NewTree("split avg aggregation"), nil } diff --git a/go/vt/vtgate/planbuilder/operators/aggregator.go b/go/vt/vtgate/planbuilder/operators/aggregator.go index 33846f83365..685a418339a 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregator.go +++ b/go/vt/vtgate/planbuilder/operators/aggregator.go @@ -259,16 +259,16 @@ func (a *Aggregator) GetOrdering(ctx *plancontext.PlanningContext) []ops.OrderBy return a.Source.GetOrdering(ctx) } -func (a *Aggregator) planOffsets(ctx *plancontext.PlanningContext) { +func (a *Aggregator) planOffsets(ctx *plancontext.PlanningContext) ops.Operator { if a.offsetPlanned { - return + return nil } defer func() { a.offsetPlanned = true }() if !a.Pushed { a.planOffsetsNotPushed(ctx) - return + return nil } for idx, gb := range a.Grouping { @@ -291,6 +291,7 @@ func (a *Aggregator) planOffsets(ctx *plancontext.PlanningContext) { offset := a.internalAddColumn(ctx, aeWrap(weightStringFor(aggr.Func.GetArg())), true) a.Aggregations[idx].WSOffset = offset } + return nil } func (aggr Aggr) getPushColumn() sqlparser.Expr { diff --git a/go/vt/vtgate/planbuilder/operators/apply_join.go b/go/vt/vtgate/planbuilder/operators/apply_join.go index 138c17f2da7..95d7d962738 100644 --- a/go/vt/vtgate/planbuilder/operators/apply_join.go +++ b/go/vt/vtgate/planbuilder/operators/apply_join.go @@ -148,18 +148,16 @@ func (aj *ApplyJoin) IsInner() bool { return !aj.LeftJoin } -func (aj *ApplyJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) error { +func (aj *ApplyJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) { aj.Predicate = ctx.SemTable.AndExpressions(expr, aj.Predicate) - col, err := BreakExpressionInLHSandRHS(ctx, expr, TableID(aj.LHS)) + col, err := breakExpressionInLHSandRHSForApplyJoin(ctx, expr, TableID(aj.LHS)) if err != nil { - return err + panic(err) } aj.JoinPredicates = append(aj.JoinPredicates, col) rhs := aj.RHS.AddPredicate(ctx, col.RHSExpr) aj.RHS = rhs - - return nil } func (aj *ApplyJoin) pushColRight(ctx *plancontext.PlanningContext, e *sqlparser.AliasedExpr, addToGroupBy bool) (int, error) { @@ -203,7 +201,7 @@ func (aj *ApplyJoin) getJoinColumnFor(ctx *plancontext.PlanningContext, orig *sq case deps.IsSolvedBy(rhs): col.RHSExpr = e case deps.IsSolvedBy(both): - col, err = BreakExpressionInLHSandRHS(ctx, e, TableID(aj.LHS)) + col, err = breakExpressionInLHSandRHSForApplyJoin(ctx, e, TableID(aj.LHS)) if err != nil { return JoinColumn{}, err } @@ -243,7 +241,7 @@ func (aj *ApplyJoin) AddColumn( return offset } -func (aj *ApplyJoin) planOffsets(ctx *plancontext.PlanningContext) { +func (aj *ApplyJoin) planOffsets(ctx *plancontext.PlanningContext) ops.Operator { for _, col := range aj.JoinColumns { // Read the type description for JoinColumn to understand the following code for _, lhsExpr := range col.LHSExprs { @@ -272,6 +270,8 @@ func (aj *ApplyJoin) planOffsets(ctx *plancontext.PlanningContext) { offset := aj.LHS.AddColumn(ctx, true, false, aeWrap(lhsExpr.Expr)) aj.Vars[lhsExpr.Name] = offset } + + return nil } func (aj *ApplyJoin) addOffset(offset int) { diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index 64d9826a80e..46286b2778f 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -165,7 +165,7 @@ func (jpc *joinPredicateCollector) inspectPredicate( // then we can use this predicate to connect the subquery to the outer query if !deps.IsSolvedBy(jpc.subqID) && deps.IsSolvedBy(jpc.totalID) { jpc.predicates = append(jpc.predicates, predicate) - jc, err := BreakExpressionInLHSandRHS(ctx, predicate, jpc.outerID) + jc, err := breakExpressionInLHSandRHSForApplyJoin(ctx, predicate, jpc.outerID) if err != nil { return err } diff --git a/go/vt/vtgate/planbuilder/operators/distinct.go b/go/vt/vtgate/planbuilder/operators/distinct.go index 88503514615..d7aad08d206 100644 --- a/go/vt/vtgate/planbuilder/operators/distinct.go +++ b/go/vt/vtgate/planbuilder/operators/distinct.go @@ -46,7 +46,7 @@ type ( } ) -func (d *Distinct) planOffsets(ctx *plancontext.PlanningContext) { +func (d *Distinct) planOffsets(ctx *plancontext.PlanningContext) ops.Operator { columns := d.GetColumns(ctx) for idx, col := range columns { e, err := d.QP.GetSimplifiedExpr(ctx, col.Expr) @@ -68,6 +68,7 @@ func (d *Distinct) planOffsets(ctx *plancontext.PlanningContext) { Type: typ, }) } + return nil } func (d *Distinct) Clone(inputs []ops.Operator) ops.Operator { diff --git a/go/vt/vtgate/planbuilder/operators/expressions.go b/go/vt/vtgate/planbuilder/operators/expressions.go index 7ab27e787e8..0df875a6fbd 100644 --- a/go/vt/vtgate/planbuilder/operators/expressions.go +++ b/go/vt/vtgate/planbuilder/operators/expressions.go @@ -22,9 +22,9 @@ import ( "vitess.io/vitess/go/vt/vtgate/semantics" ) -// BreakExpressionInLHSandRHS takes an expression and +// breakExpressionInLHSandRHSForApplyJoin takes an expression and // extracts the parts that are coming from one of the sides into `ColName`s that are needed -func BreakExpressionInLHSandRHS( +func breakExpressionInLHSandRHSForApplyJoin( ctx *plancontext.PlanningContext, expr sqlparser.Expr, lhs semantics.TableSet, diff --git a/go/vt/vtgate/planbuilder/operators/filter.go b/go/vt/vtgate/planbuilder/operators/filter.go index ed43910b75d..cee57c74943 100644 --- a/go/vt/vtgate/planbuilder/operators/filter.go +++ b/go/vt/vtgate/planbuilder/operators/filter.go @@ -107,7 +107,7 @@ func (f *Filter) GetOrdering(ctx *plancontext.PlanningContext) []ops.OrderBy { func (f *Filter) Compact(*plancontext.PlanningContext) (ops.Operator, *rewrite.ApplyResult, error) { if len(f.Predicates) == 0 { - return f.Source, rewrite.NewTree("filter with no predicates removed", f), nil + return f.Source, rewrite.NewTree("filter with no predicates removed"), nil } other, isFilter := f.Source.(*Filter) @@ -116,10 +116,10 @@ func (f *Filter) Compact(*plancontext.PlanningContext) (ops.Operator, *rewrite.A } f.Source = other.Source f.Predicates = append(f.Predicates, other.Predicates...) - return f, rewrite.NewTree("two filters merged into one", f), nil + return f, rewrite.NewTree("two filters merged into one"), nil } -func (f *Filter) planOffsets(ctx *plancontext.PlanningContext) { +func (f *Filter) planOffsets(ctx *plancontext.PlanningContext) ops.Operator { cfg := &evalengine.Config{ ResolveType: ctx.SemTable.TypeForExpr, Collation: ctx.SemTable.Collation, @@ -136,6 +136,7 @@ func (f *Filter) planOffsets(ctx *plancontext.PlanningContext) { } f.PredicateWithOffsets = eexpr + return nil } func (f *Filter) ShortDescription() string { diff --git a/go/vt/vtgate/planbuilder/operators/hash_join.go b/go/vt/vtgate/planbuilder/operators/hash_join.go new file mode 100644 index 00000000000..e9cfeb7d107 --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/hash_join.go @@ -0,0 +1,327 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package operators + +import ( + "fmt" + "slices" + "strings" + + "vitess.io/vitess/go/slice" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +type ( + HashJoin struct { + LHS, RHS ops.Operator + + // LeftJoin will be true in the case of an outer join + LeftJoin bool + + // Before offset planning + JoinComparisons []Comparison + + // These columns are the output columns of the hash join. While in operator mode we keep track of complex expression, + // but once we move to the engine primitives, the hash join only passes through column from either left or right. + // anything more complex will be solved by a projection on top of the hash join + columns []sqlparser.Expr + + // After offset planning + + // Columns stores the column indexes of the columns coming from the left and right side + // negative value comes from LHS and positive from RHS + ColumnOffsets []int + + // These are the values that will be hashed together + LHSKeys, RHSKeys []int + + offset bool + } + + Comparison struct { + LHS, RHS sqlparser.Expr + } +) + +var _ ops.Operator = (*HashJoin)(nil) +var _ JoinOp = (*HashJoin)(nil) + +func NewHashJoin(lhs, rhs ops.Operator, outerJoin bool) *HashJoin { + hj := &HashJoin{ + LHS: lhs, + RHS: rhs, + LeftJoin: outerJoin, + } + return hj +} + +func (hj *HashJoin) Clone(inputs []ops.Operator) ops.Operator { + kopy := *hj + kopy.LHS, kopy.RHS = inputs[0], inputs[1] + kopy.columns = slices.Clone(hj.columns) + kopy.LHSKeys = slices.Clone(hj.LHSKeys) + kopy.RHSKeys = slices.Clone(hj.RHSKeys) + return &kopy +} + +func (hj *HashJoin) Inputs() []ops.Operator { + return []ops.Operator{hj.LHS, hj.RHS} +} + +func (hj *HashJoin) SetInputs(operators []ops.Operator) { + hj.LHS, hj.RHS = operators[0], operators[1] +} + +func (hj *HashJoin) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) ops.Operator { + return AddPredicate(ctx, hj, expr, false, newFilter) +} + +func (hj *HashJoin) AddColumn(ctx *plancontext.PlanningContext, reuseExisting bool, addToGroupBy bool, expr *sqlparser.AliasedExpr) int { + if reuseExisting { + offset := hj.FindCol(ctx, expr.Expr, false) + if offset >= 0 { + return offset + } + } + + hj.columns = append(hj.columns, expr.Expr) + return len(hj.columns) - 1 +} + +func (hj *HashJoin) planOffsets(ctx *plancontext.PlanningContext) ops.Operator { + if hj.offset { + return nil + } + hj.offset = true + for _, cmp := range hj.JoinComparisons { + lOffset := hj.LHS.AddColumn(ctx, true, false, aeWrap(cmp.LHS)) + hj.LHSKeys = append(hj.LHSKeys, lOffset) + rOffset := hj.RHS.AddColumn(ctx, true, false, aeWrap(cmp.RHS)) + hj.RHSKeys = append(hj.RHSKeys, rOffset) + } + + eexprs := slice.Map(hj.columns, func(in sqlparser.Expr) *ProjExpr { + return hj.addColumn(ctx, in) + }) + + proj := newAliasedProjection(hj) + _, err := proj.addProjExpr(eexprs...) + if err != nil { + panic(err) + } + + return proj +} + +func (hj *HashJoin) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, underRoute bool) int { + for offset, col := range hj.columns { + if ctx.SemTable.EqualsExprWithDeps(expr, col) { + return offset + } + } + return -1 +} + +func (hj *HashJoin) GetColumns(*plancontext.PlanningContext) []*sqlparser.AliasedExpr { + return slice.Map(hj.columns, aeWrap) +} + +func (hj *HashJoin) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.SelectExprs { + return transformColumnsToSelectExprs(ctx, hj) +} + +func (hj *HashJoin) ShortDescription() string { + comparisons := slice.Map(hj.JoinComparisons, func(from Comparison) string { + return from.String() + }) + cmp := strings.Join(comparisons, " AND ") + + if len(hj.columns) > 0 { + return fmt.Sprintf("%s columns %v", cmp, hj.columns) + } + + return cmp +} + +func (hj *HashJoin) GetOrdering(ctx *plancontext.PlanningContext) []ops.OrderBy { + return nil // hash joins will never promise an output order +} + +func (hj *HashJoin) GetLHS() ops.Operator { + return hj.LHS +} + +func (hj *HashJoin) GetRHS() ops.Operator { + return hj.RHS +} + +func (hj *HashJoin) SetLHS(op ops.Operator) { + hj.LHS = op +} + +func (hj *HashJoin) SetRHS(op ops.Operator) { + hj.RHS = op +} + +func (hj *HashJoin) MakeInner() { + hj.LeftJoin = false +} + +func (hj *HashJoin) IsInner() bool { + return !hj.LeftJoin +} + +func (hj *HashJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) { + cmp, ok := expr.(*sqlparser.ComparisonExpr) + if !ok || !canBeSolvedWithHashJoin(cmp.Operator) { + panic(vterrors.VT12001(fmt.Sprintf("can't use [%s] with hash joins", sqlparser.String(expr)))) + } + lExpr := cmp.Left + lDeps := ctx.SemTable.RecursiveDeps(lExpr) + rExpr := cmp.Right + rDeps := ctx.SemTable.RecursiveDeps(rExpr) + lID := TableID(hj.LHS) + rID := TableID(hj.RHS) + if !lDeps.IsSolvedBy(lID) || !rDeps.IsSolvedBy(rID) { + // we'll switch and see if things work out then + lExpr, rExpr = rExpr, lExpr + lDeps, rDeps = rDeps, lDeps + } + + if !lDeps.IsSolvedBy(lID) || !rDeps.IsSolvedBy(rID) { + panic(vterrors.VT12001(fmt.Sprintf("can't use [%s] with hash joins", sqlparser.String(expr)))) + } + + hj.JoinComparisons = append(hj.JoinComparisons, Comparison{ + LHS: lExpr, + RHS: rExpr, + }) +} + +func canBeSolvedWithHashJoin(op sqlparser.ComparisonExprOperator) bool { + switch op { + case sqlparser.EqualOp, sqlparser.NullSafeEqualOp: + return true + default: + return false + } +} + +func (c Comparison) String() string { + return sqlparser.String(c.LHS) + " = " + sqlparser.String(c.RHS) +} + +func (hj *HashJoin) addColumn(ctx *plancontext.PlanningContext, in sqlparser.Expr) *ProjExpr { + lId, rId := TableID(hj.LHS), TableID(hj.RHS) + var replaceExpr sqlparser.Expr // this is the expression we will put in instead of whatever we find there + pre := func(node, parent sqlparser.SQLNode) bool { + expr, ok := node.(sqlparser.Expr) + if !ok { + return true + } + deps := ctx.SemTable.RecursiveDeps(expr) + check := func(id semantics.TableSet, op ops.Operator, offsetter func(int) int) int { + if !deps.IsSolvedBy(id) { + return -1 + } + inOffset := op.FindCol(ctx, expr, false) + if inOffset == -1 { + if !fetchByOffset(expr) { + return -1 + } + + // aha! this is an expression that we have to get from the input. let's force it in there + inOffset = op.AddColumn(ctx, false, false, aeWrap(expr)) + } + + // we turn the + internalOffset := offsetter(inOffset) + + // ok, we have an offset from the input operator. Let's check if we already have it + // in our list of incoming columns + + for idx, offset := range hj.ColumnOffsets { + if internalOffset == offset { + return idx + } + } + + hj.ColumnOffsets = append(hj.ColumnOffsets, internalOffset) + + return len(hj.ColumnOffsets) - 1 + } + + f := func(i int) int { return (i * -1) - 1 } + if lOffset := check(lId, hj.LHS, f); lOffset >= 0 { + replaceExpr = sqlparser.NewOffset(lOffset, expr) + return false // we want to stop going down the expression tree and start coming back up again + } + + f = func(i int) int { return i + 1 } + if rOffset := check(rId, hj.RHS, f); rOffset >= 0 { + replaceExpr = sqlparser.NewOffset(rOffset, expr) + return false + } + + return true + } + + post := func(cursor *sqlparser.CopyOnWriteCursor) { + if replaceExpr != nil { + node := cursor.Node() + _, ok := node.(sqlparser.Expr) + if !ok { + panic(fmt.Sprintf("can't replace this node with an expression: %s", sqlparser.String(node))) + } + cursor.Replace(replaceExpr) + replaceExpr = nil + } + } + + rewrittenExpr := sqlparser.CopyOnRewrite(in, pre, post, ctx.SemTable.CopySemanticInfo).(sqlparser.Expr) + cfg := &evalengine.Config{ + ResolveType: ctx.SemTable.TypeForExpr, + Collation: ctx.SemTable.Collation, + } + eexpr, err := evalengine.Translate(rewrittenExpr, cfg) + if err != nil { + panic(err) + } + + return &ProjExpr{ + Original: aeWrap(in), + EvalExpr: rewrittenExpr, + ColExpr: rewrittenExpr, + Info: &EvalEngine{EExpr: eexpr}, + } +} + +// JoinPredicate produces an AST representation of the join condition this join has +func (hj *HashJoin) JoinPredicate() sqlparser.Expr { + exprs := slice.Map(hj.JoinComparisons, func(from Comparison) sqlparser.Expr { + return &sqlparser.ComparisonExpr{ + Left: from.LHS, + Right: from.RHS, + } + }) + return sqlparser.AndExpressions(exprs...) +} diff --git a/go/vt/vtgate/planbuilder/operators/hash_join_test.go b/go/vt/vtgate/planbuilder/operators/hash_join_test.go new file mode 100644 index 00000000000..69f21bd4b78 --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/hash_join_test.go @@ -0,0 +1,100 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package operators + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +func TestJoinPredicates(t *testing.T) { + lcol := sqlparser.NewColName("lhs") + rcol := sqlparser.NewColName("rhs") + ctx := &plancontext.PlanningContext{SemTable: semantics.EmptySemTable()} + lid := semantics.SingleTableSet(0) + rid := semantics.SingleTableSet(1) + ctx.SemTable.Recursive[lcol] = lid + ctx.SemTable.Recursive[rcol] = rid + lhs := &fakeOp{id: lid} + rhs := &fakeOp{id: rid} + hj := &HashJoin{ + LHS: lhs, + RHS: rhs, + LeftJoin: false, + } + + cmp := &sqlparser.ComparisonExpr{ + Operator: sqlparser.EqualOp, + Left: lcol, + Right: rcol, + } + hj.AddJoinPredicate(ctx, cmp) + require.Len(t, hj.JoinComparisons, 1) + hj.planOffsets(ctx) + require.Len(t, hj.LHSKeys, 1) + require.Len(t, hj.RHSKeys, 1) +} + +func TestOffsetPlanning(t *testing.T) { + lcol1, lcol2 := sqlparser.NewColName("lhs1"), sqlparser.NewColName("lhs2") + rcol1, rcol2 := sqlparser.NewColName("rhs1"), sqlparser.NewColName("rhs2") + ctx := &plancontext.PlanningContext{SemTable: semantics.EmptySemTable()} + lid := semantics.SingleTableSet(0) + rid := semantics.SingleTableSet(1) + ctx.SemTable.Recursive[lcol1] = lid + ctx.SemTable.Recursive[lcol2] = lid + ctx.SemTable.Recursive[rcol1] = rid + ctx.SemTable.Recursive[rcol2] = rid + lhs := &fakeOp{id: lid} + rhs := &fakeOp{id: rid} + + tests := []struct { + expr sqlparser.Expr + expectedColOffsets []int + }{{ + expr: lcol1, + expectedColOffsets: []int{-1}, + }, { + expr: rcol1, + expectedColOffsets: []int{1}, + }, { + expr: sqlparser.AndExpressions(lcol1, lcol2), + expectedColOffsets: []int{-1, -2}, + }, { + expr: sqlparser.AndExpressions(lcol1, rcol1, lcol2, rcol2), + expectedColOffsets: []int{-1, 1, -2, 2}, + }} + + for _, test := range tests { + t.Run(sqlparser.String(test.expr), func(t *testing.T) { + hj := &HashJoin{ + LHS: lhs, + RHS: rhs, + LeftJoin: false, + } + hj.AddColumn(ctx, true, false, aeWrap(test.expr)) + hj.planOffsets(ctx) + assert.Equal(t, test.expectedColOffsets, hj.ColumnOffsets) + }) + } +} diff --git a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go index 7ec141a1b8b..06bcf2aaeb5 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go +++ b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go @@ -72,10 +72,10 @@ func expandUnionHorizon(ctx *plancontext.PlanningContext, horizon *Horizon, unio } if op == horizon.Source { - return op, rewrite.NewTree("removed UNION horizon not used", op), nil + return op, rewrite.NewTree("removed UNION horizon not used"), nil } - return op, rewrite.NewTree("expand UNION horizon into smaller components", op), nil + return op, rewrite.NewTree("expand UNION horizon into smaller components"), nil } func expandSelectHorizon(ctx *plancontext.PlanningContext, horizon *Horizon, sel *sqlparser.Select) (ops.Operator, *rewrite.ApplyResult, error) { @@ -126,7 +126,7 @@ func expandSelectHorizon(ctx *plancontext.PlanningContext, horizon *Horizon, sel extracted = append(extracted, "Limit") } - return op, rewrite.NewTree(fmt.Sprintf("expand SELECT horizon into (%s)", strings.Join(extracted, ", ")), op), nil + return op, rewrite.NewTree(fmt.Sprintf("expand SELECT horizon into (%s)", strings.Join(extracted, ", "))), nil } func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon *Horizon) (out ops.Operator) { diff --git a/go/vt/vtgate/planbuilder/operators/join.go b/go/vt/vtgate/planbuilder/operators/join.go index 828b15f5b79..1d50a688df4 100644 --- a/go/vt/vtgate/planbuilder/operators/join.go +++ b/go/vt/vtgate/planbuilder/operators/join.go @@ -82,7 +82,7 @@ func (j *Join) Compact(ctx *plancontext.PlanningContext) (ops.Operator, *rewrite if j.Predicate != nil { newOp.collectPredicate(ctx, j.Predicate) } - return newOp, rewrite.NewTree("merge querygraphs into a single one", newOp), nil + return newOp, rewrite.NewTree("merge querygraphs into a single one"), nil } func createOuterJoin(tableExpr *sqlparser.JoinTableExpr, lhs, rhs ops.Operator) (ops.Operator, error) { @@ -162,9 +162,8 @@ func (j *Join) IsInner() bool { return !j.LeftJoin } -func (j *Join) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) error { +func (j *Join) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) { j.Predicate = ctx.SemTable.AndExpressions(j.Predicate, expr) - return nil } func (j *Join) ShortDescription() string { diff --git a/go/vt/vtgate/planbuilder/operators/joins.go b/go/vt/vtgate/planbuilder/operators/joins.go index 3b5c31c5dce..ad61a6c5a00 100644 --- a/go/vt/vtgate/planbuilder/operators/joins.go +++ b/go/vt/vtgate/planbuilder/operators/joins.go @@ -31,7 +31,7 @@ type JoinOp interface { SetRHS(ops.Operator) MakeInner() IsInner() bool - AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) error + AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) } func AddPredicate( @@ -79,10 +79,7 @@ func AddPredicate( return newFilter(join, expr) } - err := join.AddJoinPredicate(ctx, expr) - if err != nil { - panic(err) - } + join.AddJoinPredicate(ctx, expr) return join } diff --git a/go/vt/vtgate/planbuilder/operators/offset_planning.go b/go/vt/vtgate/planbuilder/operators/offset_planning.go index 7e7be49874a..d2fc266790c 100644 --- a/go/vt/vtgate/planbuilder/operators/offset_planning.go +++ b/go/vt/vtgate/planbuilder/operators/offset_planning.go @@ -30,7 +30,7 @@ import ( // planOffsets will walk the tree top down, adding offset information to columns in the tree for use in further optimization, func planOffsets(ctx *plancontext.PlanningContext, root ops.Operator) (ops.Operator, error) { type offsettable interface { - planOffsets(ctx *plancontext.PlanningContext) + planOffsets(ctx *plancontext.PlanningContext) ops.Operator } visitor := func(in ops.Operator, _ semantics.TableSet, _ bool) (ops.Operator, *rewrite.ApplyResult, error) { @@ -39,7 +39,10 @@ func planOffsets(ctx *plancontext.PlanningContext, root ops.Operator) (ops.Opera case *Horizon: return nil, nil, vterrors.VT13001(fmt.Sprintf("should not see %T here", in)) case offsettable: - op.planOffsets(ctx) + newOp := op.planOffsets(ctx) + if newOp != nil { + return newOp, rewrite.NewTree("new operator after offset planning"), nil + } } if err != nil { return nil, nil, err @@ -116,7 +119,7 @@ func addColumnsToInput(ctx *plancontext.PlanningContext, root ops.Operator) (ops _ = sqlparser.CopyOnRewrite(expr, visitor, nil, ctx.SemTable.CopySemanticInfo) } if addedColumns { - return in, rewrite.NewTree("added columns because filter needs it", in), nil + return in, rewrite.NewTree("added columns because filter needs it"), nil } return in, rewrite.SameTree, nil @@ -140,7 +143,7 @@ func pullDistinctFromUNION(_ *plancontext.PlanningContext, root ops.Operator) (o Required: true, Source: union, } - return distinct, rewrite.NewTree("pulled out DISTINCT from union", union), nil + return distinct, rewrite.NewTree("pulled out DISTINCT from union"), nil } return rewrite.TopDown(root, TableID, visitor, stopAtRoute) diff --git a/go/vt/vtgate/planbuilder/operators/ordering.go b/go/vt/vtgate/planbuilder/operators/ordering.go index b3d0310eadb..66436f6a47d 100644 --- a/go/vt/vtgate/planbuilder/operators/ordering.go +++ b/go/vt/vtgate/planbuilder/operators/ordering.go @@ -78,7 +78,7 @@ func (o *Ordering) GetOrdering(*plancontext.PlanningContext) []ops.OrderBy { return o.Order } -func (o *Ordering) planOffsets(ctx *plancontext.PlanningContext) { +func (o *Ordering) planOffsets(ctx *plancontext.PlanningContext) ops.Operator { for _, order := range o.Order { offset := o.Source.AddColumn(ctx, true, false, aeWrap(order.SimplifiedExpr)) o.Offset = append(o.Offset, offset) @@ -92,6 +92,7 @@ func (o *Ordering) planOffsets(ctx *plancontext.PlanningContext) { offset = o.Source.AddColumn(ctx, true, false, aeWrap(wsExpr)) o.WOffset = append(o.WOffset, offset) } + return nil } func (o *Ordering) ShortDescription() string { diff --git a/go/vt/vtgate/planbuilder/operators/phases.go b/go/vt/vtgate/planbuilder/operators/phases.go index 36149f89985..f180cfc2ce0 100644 --- a/go/vt/vtgate/planbuilder/operators/phases.go +++ b/go/vt/vtgate/planbuilder/operators/phases.go @@ -96,15 +96,18 @@ func (p Phase) act(ctx *plancontext.PlanningContext, op ops.Operator) (ops.Opera } } -// getPhases returns the ordered phases that the planner will undergo. -// These phases ensure the appropriate collaboration between rewriters. -func getPhases(ctx *plancontext.PlanningContext) (phases []Phase) { - for p := Phase(0); p < DONE; p++ { - if p.shouldRun(ctx.SemTable.QuerySignature) { - phases = append(phases, p) +type phaser struct { + current Phase +} + +func (p *phaser) next(ctx *plancontext.PlanningContext) Phase { + for phas := p.current; phas < DONE; phas++ { + if phas.shouldRun(ctx.SemTable.QuerySignature) { + p.current = p.current + 1 + return phas } } - return + return DONE } func removePerformanceDistinctAboveRoute(_ *plancontext.PlanningContext, op ops.Operator) (ops.Operator, error) { @@ -114,7 +117,7 @@ func removePerformanceDistinctAboveRoute(_ *plancontext.PlanningContext, op ops. return innerOp, rewrite.SameTree, nil } - return d.Source, rewrite.NewTree("removed distinct not required that was not pushed under route", d), nil + return d.Source, rewrite.NewTree("removed distinct not required that was not pushed under route"), nil }, stopAtRoute) } @@ -138,7 +141,7 @@ func addOrderingForAllAggregations(ctx *plancontext.PlanningContext, root ops.Op var res *rewrite.ApplyResult if requireOrdering { addOrderingFor(aggrOp) - res = rewrite.NewTree("added ordering before aggregation", in) + res = rewrite.NewTree("added ordering before aggregation") } return in, res, nil } diff --git a/go/vt/vtgate/planbuilder/operators/projection.go b/go/vt/vtgate/planbuilder/operators/projection.go index bad9cb0e87e..7e9f2d71a71 100644 --- a/go/vt/vtgate/planbuilder/operators/projection.go +++ b/go/vt/vtgate/planbuilder/operators/projection.go @@ -263,14 +263,14 @@ func (p *Projection) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Ex return -1 } -func (p *Projection) addProjExpr(pe *ProjExpr) (int, error) { +func (p *Projection) addProjExpr(pe ...*ProjExpr) (int, error) { ap, err := p.GetAliasedProjections() if err != nil { return 0, err } offset := len(ap) - ap = append(ap, pe) + ap = append(ap, pe...) p.Columns = ap return offset, nil @@ -471,7 +471,7 @@ func (p *Projection) Compact(ctx *plancontext.PlanningContext) (ops.Operator, *r } if !needed { - return p.Source, rewrite.NewTree("removed projection only passing through the input", p), nil + return p.Source, rewrite.NewTree("removed projection only passing through the input"), nil } switch src := p.Source.(type) { @@ -517,7 +517,7 @@ func (p *Projection) compactWithJoin(ctx *plancontext.PlanningContext, join *App } join.Columns = newColumns join.JoinColumns = newColumnsAST - return join, rewrite.NewTree("remove projection from before join", join), nil + return join, rewrite.NewTree("remove projection from before join"), nil } func (p *Projection) compactWithRoute(ctx *plancontext.PlanningContext, rb *Route) (ops.Operator, *rewrite.ApplyResult, error) { @@ -538,7 +538,7 @@ func (p *Projection) compactWithRoute(ctx *plancontext.PlanningContext, rb *Rout } if len(columns) == len(ap) { - return rb, rewrite.NewTree("remove projection from before route", rb), nil + return rb, rewrite.NewTree("remove projection from before route"), nil } rb.ResultColumns = len(columns) return rb, rewrite.SameTree, nil @@ -561,7 +561,7 @@ func (p *Projection) needsEvaluation(ctx *plancontext.PlanningContext, e sqlpars return false } -func (p *Projection) planOffsets(ctx *plancontext.PlanningContext) { +func (p *Projection) planOffsets(ctx *plancontext.PlanningContext) ops.Operator { ap, err := p.GetAliasedProjections() if err != nil { panic(err) @@ -597,6 +597,7 @@ func (p *Projection) planOffsets(ctx *plancontext.PlanningContext) { EExpr: eexpr, } } + return nil } func (p *Projection) introducesTableID() semantics.TableSet { diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index e66abd02feb..1f239a31973 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -65,7 +65,9 @@ func planQuery(ctx *plancontext.PlanningContext, root ops.Operator) (output ops. // smaller operators and try to push these down as far as possible func runPhases(ctx *plancontext.PlanningContext, root ops.Operator) (op ops.Operator, err error) { op = root - for _, phase := range getPhases(ctx) { + + p := phaser{} + for phase := p.next(ctx); phase != DONE; phase = p.next(ctx) { ctx.CurrentPhase = int(phase) if rewrite.DebugOperatorTree { fmt.Printf("PHASE: %s\n", phase.String()) @@ -134,7 +136,7 @@ func pushLockAndComment(l *LockAndComment) (ops.Operator, *rewrite.ApplyResult, case *Route: src.Comments = l.Comments src.Lock = l.Lock - return src, rewrite.NewTree("put lock and comment into route", l), nil + return src, rewrite.NewTree("put lock and comment into route"), nil default: inputs := src.Inputs() for i, op := range inputs { @@ -145,7 +147,7 @@ func pushLockAndComment(l *LockAndComment) (ops.Operator, *rewrite.ApplyResult, } } src.SetInputs(inputs) - return src, rewrite.NewTree("pushed down lock and comments", l), nil + return src, rewrite.NewTree("pushed down lock and comments"), nil } } @@ -256,7 +258,7 @@ func pushProjectionToOuter(ctx *plancontext.PlanningContext, p *Projection, sq * } // all projections can be pushed to the outer sq.Outer, p.Source = p, sq.Outer - return sq, rewrite.NewTree("push projection into outer side of subquery", p), nil + return sq, rewrite.NewTree("push projection into outer side of subquery"), nil } func pushProjectionInVindex( @@ -271,7 +273,7 @@ func pushProjectionInVindex( for _, pe := range ap { src.AddColumn(ctx, true, false, aeWrap(pe.EvalExpr)) } - return src, rewrite.NewTree("push projection into vindex", p), nil + return src, rewrite.NewTree("push projection into vindex"), nil } func (p *projector) add(pe *ProjExpr, col *sqlparser.IdentifierCI) { @@ -331,7 +333,7 @@ func pushProjectionInApplyJoin( return nil, nil, err } - return src, rewrite.NewTree("split projection to either side of join", src), nil + return src, rewrite.NewTree("split projection to either side of join"), nil } // splitProjectionAcrossJoin creates JoinPredicates for all projections, @@ -521,7 +523,7 @@ func setUpperLimit(in *Limit) (ops.Operator, *rewrite.ApplyResult, error) { Pushed: false, } op.Source = newSrc - result = result.Merge(rewrite.NewTree("push limit under route", newSrc)) + result = result.Merge(rewrite.NewTree("push limit under route")) return rewrite.SkipChildren default: return rewrite.VisitChildren @@ -544,12 +546,12 @@ func tryPushOrdering(ctx *plancontext.PlanningContext, in *Ordering) (ops.Operat // ApplyJoin is stable in regard to the columns coming from the LHS, // so if all the ordering columns come from the LHS, we can push down the Ordering there src.LHS, in.Source = in, src.LHS - return src, rewrite.NewTree("push down ordering on the LHS of a join", in), nil + return src, rewrite.NewTree("push down ordering on the LHS of a join"), nil } case *Ordering: // we'll just remove the order underneath. The top order replaces whatever was incoming in.Source = src.Source - return in, rewrite.NewTree("remove double ordering", src), nil + return in, rewrite.NewTree("remove double ordering"), nil case *Projection: // we can move ordering under a projection if it's not introducing a column we're sorting by for _, by := range in.Order { @@ -573,7 +575,7 @@ func tryPushOrdering(ctx *plancontext.PlanningContext, in *Ordering) (ops.Operat } } src.Outer, in.Source = in, src.Outer - return src, rewrite.NewTree("push ordering into outer side of subquery", in), nil + return src, rewrite.NewTree("push ordering into outer side of subquery"), nil case *SubQuery: outerTableID := TableID(src.Outer) for _, order := range in.Order { @@ -583,7 +585,7 @@ func tryPushOrdering(ctx *plancontext.PlanningContext, in *Ordering) (ops.Operat } } src.Outer, in.Source = in, src.Outer - return src, rewrite.NewTree("push ordering into outer side of subquery", in), nil + return src, rewrite.NewTree("push ordering into outer side of subquery"), nil } return in, rewrite.SameTree, nil } @@ -662,7 +664,7 @@ func pushOrderingUnderAggr(ctx *plancontext.PlanningContext, order *Ordering, ag order.Source = aggrSource.Source aggrSource.Source = nil // removing from plan tree aggregator.Source = order - return aggregator, rewrite.NewTree("push ordering under aggregation, removing extra ordering", aggregator), nil + return aggregator, rewrite.NewTree("push ordering under aggregation, removing extra ordering"), nil } return rewrite.Swap(order, aggregator, "push ordering under aggregation") } @@ -719,7 +721,7 @@ func tryPushFilter(ctx *plancontext.PlanningContext, in *Filter) (ops.Operator, } } src.Outer, in.Source = in, src.Outer - return src, rewrite.NewTree("push filter to outer query in subquery container", in), nil + return src, rewrite.NewTree("push filter to outer query in subquery container"), nil } return in, rewrite.SameTree, nil @@ -756,7 +758,7 @@ func tryPushDistinct(in *Distinct) (ops.Operator, *rewrite.ApplyResult, error) { switch src := in.Source.(type) { case *Route: if isDistinct(src.Source) && src.IsSingleShard() { - return src, rewrite.NewTree("distinct not needed", in), nil + return src, rewrite.NewTree("distinct not needed"), nil } if src.IsSingleShard() || !in.Required { return rewrite.Swap(in, src, "push distinct under route") @@ -769,31 +771,31 @@ func tryPushDistinct(in *Distinct) (ops.Operator, *rewrite.ApplyResult, error) { src.Source = &Distinct{Source: src.Source} in.PushedPerformance = true - return in, rewrite.NewTree("added distinct under route - kept original", src), nil + return in, rewrite.NewTree("added distinct under route - kept original"), nil case *Distinct: src.Required = false src.PushedPerformance = false - return src, rewrite.NewTree("remove double distinct", src), nil + return src, rewrite.NewTree("remove double distinct"), nil case *Union: for i := range src.Sources { src.Sources[i] = &Distinct{Source: src.Sources[i]} } in.PushedPerformance = true - return in, rewrite.NewTree("push down distinct under union", src), nil + return in, rewrite.NewTree("push down distinct under union"), nil case *ApplyJoin: src.LHS = &Distinct{Source: src.LHS} src.RHS = &Distinct{Source: src.RHS} in.PushedPerformance = true if in.Required { - return in, rewrite.NewTree("push distinct under join - kept original", in.Source), nil + return in, rewrite.NewTree("push distinct under join - kept original"), nil } - return in.Source, rewrite.NewTree("push distinct under join", in.Source), nil + return in.Source, rewrite.NewTree("push distinct under join"), nil case *Ordering: in.Source = src.Source - return in, rewrite.NewTree("remove ordering under distinct", in), nil + return in, rewrite.NewTree("remove ordering under distinct"), nil } return in, rewrite.SameTree, nil @@ -835,19 +837,19 @@ func tryPushUnion(ctx *plancontext.PlanningContext, op *Union) (ops.Operator, *r if len(sources) == 1 { result := sources[0].(*Route) if result.IsSingleShard() || !op.distinct { - return result, rewrite.NewTree("push union under route", op), nil + return result, rewrite.NewTree("push union under route"), nil } return &Distinct{ Source: result, Required: true, - }, rewrite.NewTree("push union under route", op), nil + }, rewrite.NewTree("push union under route"), nil } if len(sources) == len(op.Sources) { return op, rewrite.SameTree, nil } - return newUnion(sources, selects, op.unionColumns, op.distinct), rewrite.NewTree("merge union inputs", op), nil + return newUnion(sources, selects, op.unionColumns, op.distinct), rewrite.NewTree("merge union inputs"), nil } // addTruncationOrProjectionToReturnOutput uses the original Horizon to make sure that the output columns line up with what the user asked for diff --git a/go/vt/vtgate/planbuilder/operators/rewrite/rewriters.go b/go/vt/vtgate/planbuilder/operators/rewrite/rewriters.go index c5a8b0a6fa2..1ecc0cd8e76 100644 --- a/go/vt/vtgate/planbuilder/operators/rewrite/rewriters.go +++ b/go/vt/vtgate/planbuilder/operators/rewrite/rewriters.go @@ -45,7 +45,6 @@ type ( Rewrite struct { Message string - Op ops.Operator } // VisitRule signals to the rewriter if the children of this operator should be visited or not @@ -61,11 +60,11 @@ const ( SkipChildren VisitRule = false ) -func NewTree(message string, op ops.Operator) *ApplyResult { +func NewTree(message string) *ApplyResult { if DebugOperatorTree { fmt.Println(">>>>>>>> " + message) } - return &ApplyResult{Transformations: []Rewrite{{Message: message, Op: op}}} + return &ApplyResult{Transformations: []Rewrite{{Message: message}}} } func (ar *ApplyResult) Merge(other *ApplyResult) *ApplyResult { @@ -145,19 +144,6 @@ func FixedPointBottomUp( return op, nil } -// BottomUpAll rewrites an operator tree from the bottom up. BottomUp applies a transformation function to -// the given operator tree from the bottom up. Each callback [f] returns a ApplyResult that is aggregated -// into a final output indicating whether the operator tree was changed. -func BottomUpAll( - root ops.Operator, - resolveID func(ops.Operator) semantics.TableSet, - visit VisitF, -) (ops.Operator, error) { - return BottomUp(root, resolveID, visit, func(ops.Operator) VisitRule { - return VisitChildren - }) -} - // TopDown rewrites an operator tree from the bottom up. BottomUp applies a transformation function to // the given operator tree from the bottom up. Each callback [f] returns a ApplyResult that is aggregated // into a final output indicating whether the operator tree was changed. @@ -208,7 +194,7 @@ func Swap(parent, child ops.Operator, message string) (ops.Operator, *ApplyResul child.SetInputs([]ops.Operator{parent}) parent.SetInputs(aInputs) - return child, NewTree(message, parent), nil + return child, NewTree(message), nil } func bottomUp( diff --git a/go/vt/vtgate/planbuilder/operators/route.go b/go/vt/vtgate/planbuilder/operators/route.go index d4b2c43ecff..acbc28553dd 100644 --- a/go/vt/vtgate/planbuilder/operators/route.go +++ b/go/vt/vtgate/planbuilder/operators/route.go @@ -681,17 +681,17 @@ func isSpecialOrderBy(o ops.OrderBy) bool { return isFunction && f.Name.Lowered() == "rand" } -func (r *Route) planOffsets(ctx *plancontext.PlanningContext) { +func (r *Route) planOffsets(ctx *plancontext.PlanningContext) ops.Operator { // if operator is returning data from a single shard, we don't need to do anything more if r.IsSingleShard() { - return + return nil } // if we are getting results from multiple shards, we need to do a merge-sort // between them to get the final output correctly sorted ordering := r.Source.GetOrdering(ctx) if len(ordering) == 0 { - return + return nil } for _, order := range ordering { @@ -713,6 +713,7 @@ func (r *Route) planOffsets(ctx *plancontext.PlanningContext) { } r.Ordering = append(r.Ordering, o) } + return nil } func weightStringFor(expr sqlparser.Expr) sqlparser.Expr { diff --git a/go/vt/vtgate/planbuilder/operators/route_planning.go b/go/vt/vtgate/planbuilder/operators/route_planning.go index 306158a06da..5bce334a609 100644 --- a/go/vt/vtgate/planbuilder/operators/route_planning.go +++ b/go/vt/vtgate/planbuilder/operators/route_planning.go @@ -75,7 +75,7 @@ func optimizeQueryGraph(ctx *plancontext.PlanningContext, op *QueryGraph) (resul result = newFilter(result, ctx.SemTable.AndExpressions(unresolved...)) } - changed = rewrite.NewTree("solved query graph", result) + changed = rewrite.NewTree("solved query graph") return } @@ -365,16 +365,18 @@ func requiresSwitchingSides(ctx *plancontext.PlanningContext, op ops.Operator) b func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs ops.Operator, joinPredicates []sqlparser.Expr, inner bool) (ops.Operator, *rewrite.ApplyResult, error) { newPlan := mergeJoinInputs(ctx, lhs, rhs, joinPredicates, newJoinMerge(joinPredicates, inner)) if newPlan != nil { - return newPlan, rewrite.NewTree("merge routes into single operator", newPlan), nil + return newPlan, rewrite.NewTree("merge routes into single operator"), nil } if len(joinPredicates) > 0 && requiresSwitchingSides(ctx, rhs) { - if !inner { - return nil, nil, vterrors.VT12001("LEFT JOIN with LIMIT on the outer side") - } - - if requiresSwitchingSides(ctx, lhs) { - return nil, nil, vterrors.VT12001("JOIN between derived tables with LIMIT") + if !inner || requiresSwitchingSides(ctx, lhs) { + // we can't switch sides, so let's see if we can use a HashJoin to solve it + join := NewHashJoin(lhs, rhs, !inner) + for _, pred := range joinPredicates { + join.AddJoinPredicate(ctx, pred) + } + ctx.SemTable.QuerySignature.HashJoin = true + return join, rewrite.NewTree("use a hash join because we have LIMIT on the LHS"), nil } join := NewApplyJoin(Clone(rhs), Clone(lhs), nil, !inner) @@ -382,7 +384,7 @@ func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs ops.Operator, joinPr if err != nil { return nil, nil, err } - return newOp, rewrite.NewTree("logical join to applyJoin, switching side because LIMIT", newOp), nil + return newOp, rewrite.NewTree("logical join to applyJoin, switching side because LIMIT"), nil } join := NewApplyJoin(Clone(lhs), Clone(rhs), nil, !inner) @@ -390,7 +392,7 @@ func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs ops.Operator, joinPr if err != nil { return nil, nil, err } - return newOp, rewrite.NewTree("logical join to applyJoin ", newOp), nil + return newOp, rewrite.NewTree("logical join to applyJoin "), nil } func operatorsToRoutes(a, b ops.Operator) (*Route, *Route) { diff --git a/go/vt/vtgate/planbuilder/operators/subquery.go b/go/vt/vtgate/planbuilder/operators/subquery.go index e06e595f689..ae28dd8d9c6 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery.go +++ b/go/vt/vtgate/planbuilder/operators/subquery.go @@ -54,7 +54,7 @@ type SubQuery struct { IsProjection bool } -func (sq *SubQuery) planOffsets(ctx *plancontext.PlanningContext) { +func (sq *SubQuery) planOffsets(ctx *plancontext.PlanningContext) ops.Operator { sq.Vars = make(map[string]int) columns, err := sq.GetJoinColumns(ctx, sq.Outer) if err != nil { @@ -66,6 +66,7 @@ func (sq *SubQuery) planOffsets(ctx *plancontext.PlanningContext) { sq.Vars[lhsExpr.Name] = offset } } + return nil } func (sq *SubQuery) OuterExpressionsNeeded(ctx *plancontext.PlanningContext, outer ops.Operator) (result []*sqlparser.ColName, err error) { @@ -97,7 +98,7 @@ func (sq *SubQuery) GetJoinColumns(ctx *plancontext.PlanningContext, outer ops.O } sq.outerID = outerID mapper := func(in sqlparser.Expr) (JoinColumn, error) { - return BreakExpressionInLHSandRHS(ctx, in, outerID) + return breakExpressionInLHSandRHSForApplyJoin(ctx, in, outerID) } joinPredicates, err := slice.MapWithError(sq.Predicates, mapper) if err != nil { diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index 7740ca3d46d..74761aef5c5 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -87,7 +87,7 @@ func settleSubqueries(ctx *plancontext.PlanningContext, op ops.Operator) ops.Ope subq.Outer = newOuter outer = subq } - return outer, rewrite.NewTree("extracted subqueries from subquery container", outer), nil + return outer, rewrite.NewTree("extracted subqueries from subquery container"), nil case *Projection: ap, err := op.GetAliasedProjections() if err != nil { @@ -233,7 +233,7 @@ func tryPushSubQueryInJoin( if deps.IsSolvedBy(lhs) { // we can safely push down the subquery on the LHS outer.LHS = addSubQuery(outer.LHS, inner) - return outer, rewrite.NewTree("push subquery into LHS of join", inner), nil + return outer, rewrite.NewTree("push subquery into LHS of join"), nil } if outer.LeftJoin || len(inner.Predicates) == 0 { @@ -245,7 +245,7 @@ func tryPushSubQueryInJoin( if deps.IsSolvedBy(rhs) { // we can push down the subquery filter on RHS of the join outer.RHS = addSubQuery(outer.RHS, inner) - return outer, rewrite.NewTree("push subquery into RHS of join", inner), nil + return outer, rewrite.NewTree("push subquery into RHS of join"), nil } if deps.IsSolvedBy(joinID) { @@ -258,7 +258,7 @@ func tryPushSubQueryInJoin( } outer.RHS = addSubQuery(outer.RHS, inner) - return outer, rewrite.NewTree("push subquery into RHS of join rewriting predicates", inner), nil + return outer, rewrite.NewTree("push subquery into RHS of join rewriting predicates"), nil } return nil, rewrite.SameTree, nil @@ -272,7 +272,7 @@ func extractLHSExpr( lhs semantics.TableSet, ) func(expr sqlparser.Expr) (sqlparser.Expr, error) { return func(expr sqlparser.Expr) (sqlparser.Expr, error) { - col, err := BreakExpressionInLHSandRHS(ctx, expr, lhs) + col, err := breakExpressionInLHSandRHSForApplyJoin(ctx, expr, lhs) if err != nil { return nil, err } @@ -316,7 +316,7 @@ func tryMergeWithRHS(ctx *plancontext.PlanningContext, inner *SubQuery, outer *A outer.RHS = newOp ctx.MergedSubqueries = append(ctx.MergedSubqueries, inner.originalSubquery) - return outer, rewrite.NewTree("merged subquery with rhs of join", inner), nil + return outer, rewrite.NewTree("merged subquery with rhs of join"), nil } // addSubQuery adds a SubQuery to the given operator. If the operator is a SubQueryContainer, @@ -387,7 +387,7 @@ func pushProjectionToOuterContainer(ctx *plancontext.PlanningContext, p *Project } // all projections can be pushed to the outer src.Outer, p.Source = p, src.Outer - return src, rewrite.NewTree("push projection into outer side of subquery container", p), nil + return src, rewrite.NewTree("push projection into outer side of subquery container"), nil } func rewriteColNameToArgument(ctx *plancontext.PlanningContext, in sqlparser.Expr, se SubQueryExpression, subqueries ...*SubQuery) sqlparser.Expr { @@ -512,7 +512,7 @@ func tryMergeSubqueriesRecursively( } op.Source = &Filter{Source: outer.Source, Predicates: []sqlparser.Expr{subQuery.Original}} - return op, finalResult.Merge(rewrite.NewTree("merge outer of two subqueries", subQuery)), nil + return op, finalResult.Merge(rewrite.NewTree("merge outer of two subqueries")), nil } func tryMergeSubqueryWithOuter(ctx *plancontext.PlanningContext, subQuery *SubQuery, outer *Route, inner ops.Operator) (ops.Operator, *rewrite.ApplyResult, error) { @@ -533,7 +533,7 @@ func tryMergeSubqueryWithOuter(ctx *plancontext.PlanningContext, subQuery *SubQu op.Source = &Filter{Source: outer.Source, Predicates: []sqlparser.Expr{subQuery.Original}} } ctx.MergedSubqueries = append(ctx.MergedSubqueries, subQuery.originalSubquery) - return op, rewrite.NewTree("merged subquery with outer", subQuery), nil + return op, rewrite.NewTree("merged subquery with outer"), nil } // This checked if subquery is part of the changed vindex values. Subquery cannot be merged with the outer route. diff --git a/go/vt/vtgate/planbuilder/operators/union_merging.go b/go/vt/vtgate/planbuilder/operators/union_merging.go index 4c8b02f76d8..03b7a212893 100644 --- a/go/vt/vtgate/planbuilder/operators/union_merging.go +++ b/go/vt/vtgate/planbuilder/operators/union_merging.go @@ -255,5 +255,5 @@ func compactUnion(u *Union) *rewrite.ApplyResult { u.Sources = newSources u.Selects = newSelects - return rewrite.NewTree("merged UNIONs", u) + return rewrite.NewTree("merged UNIONs") } diff --git a/go/vt/vtgate/planbuilder/operators/utils_test.go b/go/vt/vtgate/planbuilder/operators/utils_test.go new file mode 100644 index 00000000000..a7e25c1337c --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/utils_test.go @@ -0,0 +1,90 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package operators + +import ( + "slices" + + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +type fakeOp struct { + id semantics.TableSet + inputs []ops.Operator + cols []*sqlparser.AliasedExpr +} + +var _ ops.Operator = (*fakeOp)(nil) + +func (f *fakeOp) Clone(inputs []ops.Operator) ops.Operator { + return f +} + +func (f *fakeOp) Inputs() []ops.Operator { + return f.inputs +} + +func (f *fakeOp) SetInputs(operators []ops.Operator) { + // TODO implement me + panic("implement me") +} + +func (f *fakeOp) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) ops.Operator { + // TODO implement me + panic("implement me") +} + +func (f *fakeOp) AddColumn(ctx *plancontext.PlanningContext, reuseExisting bool, _ bool, expr *sqlparser.AliasedExpr) int { + if offset := f.FindCol(ctx, expr.Expr, false); reuseExisting && offset >= 0 { + return offset + } + f.cols = append(f.cols, expr) + return len(f.cols) - 1 +} + +func (f *fakeOp) FindCol(ctx *plancontext.PlanningContext, a sqlparser.Expr, underRoute bool) int { + return slices.IndexFunc(f.cols, func(b *sqlparser.AliasedExpr) bool { + return a == b.Expr + }) +} + +func (f *fakeOp) GetColumns(ctx *plancontext.PlanningContext) []*sqlparser.AliasedExpr { + // TODO implement me + panic("implement me") +} + +func (f *fakeOp) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.SelectExprs { + // TODO implement me + panic("implement me") +} + +func (f *fakeOp) ShortDescription() string { + // TODO implement me + panic("implement me") +} + +func (f *fakeOp) GetOrdering(ctx *plancontext.PlanningContext) []ops.OrderBy { + // TODO implement me + panic("implement me") +} + +func (f *fakeOp) introducesTableID() semantics.TableSet { + return f.id +} diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.json b/go/vt/vtgate/planbuilder/testdata/from_cases.json index 7a3c13c3635..476a16f8a11 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.json @@ -4050,5 +4050,130 @@ "comment": "select with a target destination", "query": "select * from `user[-]`.user_metadata", "plan": "VT09017: SELECT with a target destination is not allowed" + }, + { + "comment": "query that needs a hash join", + "query": "select id from user left join (select col from user_extra limit 10) ue on user.col = ue.col", + "plan": { + "QueryType": "SELECT", + "Original": "select id from user left join (select col from user_extra limit 10) ue on user.col = ue.col", + "Instructions": { + "OperatorType": "Projection", + "Expressions": [ + "id as id" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "HashLeftJoin", + "Collation": "binary", + "ComparisonType": "INT16", + "JoinColumnIndexes": "-2", + "Predicate": "`user`.col = ue.col", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.col, id from `user` where 1 != 1", + "Query": "select `user`.col, id from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Limit", + "Count": "10", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col from (select col from user_extra where 1 != 1) as ue where 1 != 1", + "Query": "select col from (select col from user_extra) as ue limit :__upper_limit", + "Table": "user_extra" + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] + } + }, + { + "comment": "query that needs a hash join - both sides have limits", + "query": "select id, user_id from (select id, col from user limit 10) u join (select col, user_id from user_extra limit 10) ue on u.col = ue.col", + "plan": { + "QueryType": "SELECT", + "Original": "select id, user_id from (select id, col from user limit 10) u join (select col, user_id from user_extra limit 10) ue on u.col = ue.col", + "Instructions": { + "OperatorType": "Projection", + "Expressions": [ + "id as id", + "user_id as user_id" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "HashJoin", + "Collation": "binary", + "ComparisonType": "INT16", + "JoinColumnIndexes": "-1,2", + "Predicate": "u.col = ue.col", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Limit", + "Count": "10", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id, col from (select id, col from `user` where 1 != 1) as u where 1 != 1", + "Query": "select id, col from (select id, col from `user`) as u limit :__upper_limit", + "Table": "`user`" + } + ] + }, + { + "OperatorType": "Limit", + "Count": "10", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select col, user_id from (select col, user_id from user_extra where 1 != 1) as ue where 1 != 1", + "Query": "select col, user_id from (select col, user_id from user_extra) as ue limit :__upper_limit", + "Table": "user_extra" + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/hash_joins.txt b/go/vt/vtgate/planbuilder/testdata/hash_joins.txt deleted file mode 100644 index afc175a581c..00000000000 --- a/go/vt/vtgate/planbuilder/testdata/hash_joins.txt +++ /dev/null @@ -1,531 +0,0 @@ -# Test cases in this file are currently turned off -# Multi-route unique vindex constraint (with hash join) -"select /*vt+ ALLOW_HASH_JOIN */ user_extra.id from user join user_extra on user.col = user_extra.col where user.id = 5" -{ - "QueryType": "SELECT", - "Original": "select /*vt+ ALLOW_HASH_JOIN */ user_extra.id from user join user_extra on user.col = user_extra.col where user.id = 5", - "Instructions": { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "1", - "JoinVars": { - "user_col": 0 - }, - "TableName": "`user`_user_extra", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.col from `user` where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ `user`.col from `user` where `user`.id = 5", - "Table": "`user`", - "Values": [ - "INT64(5)" - ], - "Vindex": "user_index" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select user_extra.id from user_extra where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ user_extra.id from user_extra where user_extra.col = :user_col", - "Table": "user_extra" - } - ] - } -} -{ - "QueryType": "SELECT", - "Original": "select /*vt+ ALLOW_HASH_JOIN */ user_extra.id from user join user_extra on user.col = user_extra.col where user.id = 5", - "Instructions": { - "OperatorType": "Join", - "Variant": "HashJoin", - "ComparisonType": "INT16", - "JoinColumnIndexes": "2", - "Predicate": "`user`.col = user_extra.col", - "TableName": "`user`_user_extra", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.col from `user` where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ `user`.col from `user` where `user`.id = 5", - "Table": "`user`", - "Values": [ - "INT64(5)" - ], - "Vindex": "user_index" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select user_extra.col, user_extra.id from user_extra where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ user_extra.col, user_extra.id from user_extra", - "Table": "user_extra" - } - ] - } -} - - -# Multi-route with non-route constraint, should use first route. -"select /*vt+ ALLOW_HASH_JOIN */ user_extra.id from user join user_extra on user.col = user_extra.col where 1 = 1" -{ - "QueryType": "SELECT", - "Original": "select /*vt+ ALLOW_HASH_JOIN */ user_extra.id from user join user_extra on user.col = user_extra.col where 1 = 1", - "Instructions": { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "1", - "JoinVars": { - "user_col": 0 - }, - "TableName": "`user`_user_extra", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.col from `user` where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ `user`.col from `user` where 1 = 1", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select user_extra.id from user_extra where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ user_extra.id from user_extra where user_extra.col = :user_col", - "Table": "user_extra" - } - ] - } -} -{ - "QueryType": "SELECT", - "Original": "select /*vt+ ALLOW_HASH_JOIN */ user_extra.id from user join user_extra on user.col = user_extra.col where 1 = 1", - "Instructions": { - "OperatorType": "Join", - "Variant": "HashJoin", - "ComparisonType": "INT16", - "JoinColumnIndexes": "2", - "Predicate": "`user`.col = user_extra.col", - "TableName": "`user`_user_extra", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.col from `user` where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ `user`.col from `user` where 1 = 1", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select user_extra.col, user_extra.id from user_extra where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ user_extra.col, user_extra.id from user_extra where 1 = 1", - "Table": "user_extra" - } - ] - } -} - -# wire-up on within cross-shard derived table (hash-join version) -"select /*vt+ ALLOW_HASH_JOIN */ t.id from (select user.id, user.col1 from user join user_extra on user_extra.col = user.col) as t" -{ - "QueryType": "SELECT", - "Original": "select /*vt+ ALLOW_HASH_JOIN */ t.id from (select user.id, user.col1 from user join user_extra on user_extra.col = user.col) as t", - "Instructions": { - "OperatorType": "SimpleProjection", - "Columns": [ - 0 - ], - "Inputs": [ - { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "-1,-2", - "JoinVars": { - "user_col": 2 - }, - "TableName": "`user`_user_extra", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.id, `user`.col1, `user`.col from `user` where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ `user`.id, `user`.col1, `user`.col from `user`", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select 1 from user_extra where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ 1 from user_extra where user_extra.col = :user_col", - "Table": "user_extra" - } - ] - } - ] - } -} -{ - "QueryType": "SELECT", - "Original": "select /*vt+ ALLOW_HASH_JOIN */ t.id from (select user.id, user.col1 from user join user_extra on user_extra.col = user.col) as t", - "Instructions": { - "OperatorType": "SimpleProjection", - "Columns": [ - 0 - ], - "Inputs": [ - { - "OperatorType": "Join", - "Variant": "HashJoin", - "ComparisonType": "INT16", - "JoinColumnIndexes": "-2,-3", - "Predicate": "user_extra.col = `user`.col", - "TableName": "`user`_user_extra", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.col, `user`.id, `user`.col1 from `user` where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ `user`.col, `user`.id, `user`.col1 from `user`", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select user_extra.col from user_extra where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ user_extra.col from user_extra", - "Table": "user_extra" - } - ] - } - ] - } -} - -# hash join on int columns -"select /*vt+ ALLOW_HASH_JOIN */ u.id from user as u join user as uu on u.intcol = uu.intcol" -{ - "QueryType": "SELECT", - "Original": "select /*vt+ ALLOW_HASH_JOIN */ u.id from user as u join user as uu on u.intcol = uu.intcol", - "Instructions": { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "-1", - "JoinVars": { - "u_intcol": 1 - }, - "TableName": "`user`_`user`", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select u.id, u.intcol from `user` as u where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ u.id, u.intcol from `user` as u", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select 1 from `user` as uu where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ 1 from `user` as uu where uu.intcol = :u_intcol", - "Table": "`user`" - } - ] - } -} -{ - "QueryType": "SELECT", - "Original": "select /*vt+ ALLOW_HASH_JOIN */ u.id from user as u join user as uu on u.intcol = uu.intcol", - "Instructions": { - "OperatorType": "Join", - "Variant": "HashJoin", - "ComparisonType": "INT16", - "JoinColumnIndexes": "-2", - "Predicate": "u.intcol = uu.intcol", - "TableName": "`user`_`user`", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select u.intcol, u.id from `user` as u where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ u.intcol, u.id from `user` as u", - "Table": "`user`" - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select uu.intcol from `user` as uu where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ uu.intcol from `user` as uu", - "Table": "`user`" - } - ] - } -} - -# Author5.joins(books: [{orders: :customer}, :supplier]) (with hash join) -"select /*vt+ ALLOW_HASH_JOIN */ author5s.* from author5s join book6s on book6s.author5_id = author5s.id join book6s_order2s on book6s_order2s.book6_id = book6s.id join order2s on order2s.id = book6s_order2s.order2_id join customer2s on customer2s.id = order2s.customer2_id join supplier5s on supplier5s.id = book6s.supplier5_id" -{ - "QueryType": "SELECT", - "Original": "select /*vt+ ALLOW_HASH_JOIN */ author5s.* from author5s join book6s on book6s.author5_id = author5s.id join book6s_order2s on book6s_order2s.book6_id = book6s.id join order2s on order2s.id = book6s_order2s.order2_id join customer2s on customer2s.id = order2s.customer2_id join supplier5s on supplier5s.id = book6s.supplier5_id", - "Instructions": { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "-1,-2,-3,-4", - "JoinVars": { - "book6s_supplier5_id": 4 - }, - "TableName": "author5s, book6s_book6s_order2s_order2s_customer2s_supplier5s", - "Inputs": [ - { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "-1,-2,-3,-4,-5", - "JoinVars": { - "order2s_customer2_id": 5 - }, - "TableName": "author5s, book6s_book6s_order2s_order2s_customer2s", - "Inputs": [ - { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "-1,-2,-3,-4,-5,1", - "JoinVars": { - "book6s_order2s_order2_id": 5 - }, - "TableName": "author5s, book6s_book6s_order2s_order2s", - "Inputs": [ - { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "-1,-2,-3,-4,-5,1", - "JoinVars": { - "book6s_id": 5 - }, - "TableName": "author5s, book6s_book6s_order2s", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select author5s.id, author5s.`name`, author5s.created_at, author5s.updated_at, book6s.supplier5_id, book6s.id from author5s join book6s on book6s.author5_id = author5s.id where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ author5s.id, author5s.`name`, author5s.created_at, author5s.updated_at, book6s.supplier5_id, book6s.id from author5s join book6s on book6s.author5_id = author5s.id", - "Table": "author5s, book6s" - }, - { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select book6s_order2s.order2_id from book6s_order2s where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ book6s_order2s.order2_id from book6s_order2s where book6s_order2s.book6_id = :book6s_id", - "Table": "book6s_order2s", - "Values": [ - ":book6s_id" - ], - "Vindex": "binary_md5" - } - ] - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select order2s.customer2_id from order2s where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ order2s.customer2_id from order2s where order2s.id = :book6s_order2s_order2_id", - "Table": "order2s" - } - ] - }, - { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select 1 from customer2s where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ 1 from customer2s where customer2s.id = :order2s_customer2_id", - "Table": "customer2s", - "Values": [ - ":order2s_customer2_id" - ], - "Vindex": "binary_md5" - } - ] - }, - { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select 1 from supplier5s where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ 1 from supplier5s where supplier5s.id = :book6s_supplier5_id", - "Table": "supplier5s", - "Values": [ - ":book6s_supplier5_id" - ], - "Vindex": "binary_md5" - } - ] - } -} -{ - "QueryType": "SELECT", - "Original": "select /*vt+ ALLOW_HASH_JOIN */ author5s.* from author5s join book6s on book6s.author5_id = author5s.id join book6s_order2s on book6s_order2s.book6_id = book6s.id join order2s on order2s.id = book6s_order2s.order2_id join customer2s on customer2s.id = order2s.customer2_id join supplier5s on supplier5s.id = book6s.supplier5_id", - "Instructions": { - "OperatorType": "Join", - "Variant": "HashJoin", - "ComparisonType": "INT64", - "JoinColumnIndexes": "2,3,4,5", - "Predicate": "order2s.id = book6s_order2s.order2_id", - "TableName": "customer2s, order2s_author5s, book6s_book6s_order2s_supplier5s", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select order2s.id from order2s, customer2s where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ order2s.id from order2s, customer2s where customer2s.id = order2s.customer2_id", - "Table": "customer2s, order2s" - }, - { - "OperatorType": "Join", - "Variant": "HashJoin", - "ComparisonType": "INT64", - "JoinColumnIndexes": "-1,-2,-3,-4,-5", - "Predicate": "supplier5s.id = book6s.supplier5_id", - "TableName": "author5s, book6s_book6s_order2s_supplier5s", - "Inputs": [ - { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "1,-3,-4,-5,-6", - "JoinVars": { - "book6s_id": 0 - }, - "Predicate": "book6s_order2s.book6_id = book6s.id", - "TableName": "author5s, book6s_book6s_order2s", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select book6s.id, book6s.supplier5_id, author5s.id as id, author5s.`name` as `name`, author5s.created_at as created_at, author5s.updated_at as updated_at from author5s, book6s where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ book6s.id, book6s.supplier5_id, author5s.id as id, author5s.`name` as `name`, author5s.created_at as created_at, author5s.updated_at as updated_at from author5s, book6s where book6s.author5_id = author5s.id", - "Table": "author5s, book6s" - }, - { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select book6s_order2s.order2_id from book6s_order2s where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ book6s_order2s.order2_id from book6s_order2s where book6s_order2s.book6_id = :book6s_id", - "Table": "book6s_order2s", - "Values": [ - ":book6s_id" - ], - "Vindex": "binary_md5" - } - ] - }, - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select supplier5s.id from supplier5s where 1 != 1", - "Query": "select /*vt+ ALLOW_HASH_JOIN */ supplier5s.id from supplier5s", - "Table": "supplier5s" - } - ] - } - ] - } -} diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json index 1ea07455486..fdab7835738 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json @@ -330,14 +330,9 @@ "plan": "VT12001: unsupported: unmergable subquery can not be inside complex expression" }, { - "comment": "cant switch sides for outer joins", - "query": "select id from user left join (select user_id from user_extra limit 10) ue on user.id = ue.user_id", - "plan": "VT12001: unsupported: LEFT JOIN with LIMIT on the outer side" - }, - { - "comment": "limit on both sides means that we can't evaluate this at all", + "comment": "this query needs better type information to be able to use the hash join", "query": "select id from (select id from user limit 10) u join (select user_id from user_extra limit 10) ue on u.id = ue.user_id", - "plan": "VT12001: unsupported: JOIN between derived tables with LIMIT" + "plan": "VT12001: unsupported: missing type information for [u.id, ue.user_id]" }, { "comment": "multi-shard union", diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 48ab4322bc8..48060524dba 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -76,10 +76,11 @@ type ( // QuerySignature is used to identify shortcuts in the planning process QuerySignature struct { - Union bool - Aggregation bool - Distinct bool - SubQueries bool + Union, + Aggregation, + Distinct, + SubQueries, + HashJoin bool } // SemTable contains semantic analysis information about the query. @@ -505,6 +506,7 @@ func EmptySemTable() *SemTable { Direct: map[sqlparser.Expr]TableSet{}, ColumnEqualities: map[columnName][]sqlparser.Expr{}, columns: map[*sqlparser.Union]sqlparser.SelectExprs{}, + ExprTypes: make(map[sqlparser.Expr]evalengine.Type), } } diff --git a/go/vt/vttablet/tabletserver/rules/rules.go b/go/vt/vttablet/tabletserver/rules/rules.go index efbfcdf87e4..c1e572caa2c 100644 --- a/go/vt/vttablet/tabletserver/rules/rules.go +++ b/go/vt/vttablet/tabletserver/rules/rules.go @@ -27,15 +27,14 @@ import ( "time" "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vttablet/tabletserver/planbuilder" - - querypb "vitess.io/vitess/go/vt/proto/query" - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) -//----------------------------------------------- +// ----------------------------------------------- const ( bufferedTableRuleName = "buffered_table" @@ -189,7 +188,7 @@ func (qrs *Rules) GetAction( return QRContinue, nil, 0, "" } -//----------------------------------------------- +// ----------------------------------------------- // Rule represents one rule (conditions-action). // Name is meant to uniquely identify a rule. @@ -561,7 +560,7 @@ func bvMatch(bvcond BindVarCond, bindVars map[string]*querypb.BindVariable) bool return bvcond.value.eval(bv, bvcond.op, bvcond.onMismatch) } -//----------------------------------------------- +// ----------------------------------------------- // Support types for Rule // Action speficies the list of actions to perform @@ -852,13 +851,13 @@ func getint64(val *querypb.BindVariable) (iv int64, status int) { // TODO(sougou): this is inefficient. Optimize to use []byte. func getstring(val *querypb.BindVariable) (s string, status int) { - if sqltypes.IsIntegral(val.Type) || sqltypes.IsFloat(val.Type) || sqltypes.IsText(val.Type) || sqltypes.IsBinary(val.Type) { + if sqltypes.IsIntegral(val.Type) || sqltypes.IsFloat(val.Type) || sqltypes.IsTextOrBinary(val.Type) { return string(val.Value), QROK } return "", QRMismatch } -//----------------------------------------------- +// ----------------------------------------------- // Support functions for JSON // MapStrOperator maps a string representation to an Operator.