diff --git a/dbms/src/Debug/MockExecutor/AggregationBinder.cpp b/dbms/src/Debug/MockExecutor/AggregationBinder.cpp index 7b5d9b9d134..e95346af901 100644 --- a/dbms/src/Debug/MockExecutor/AggregationBinder.cpp +++ b/dbms/src/Debug/MockExecutor/AggregationBinder.cpp @@ -28,6 +28,7 @@ bool AggregationBinder::toTiPBExecutor(tipb::Executor * tipb_executor, int32_t c { tipb_executor->set_tp(tipb::ExecType::TypeAggregation); tipb_executor->set_executor_id(name); + tipb_executor->set_fine_grained_shuffle_stream_count(fine_grained_shuffle_stream_count); auto * agg = tipb_executor->mutable_aggregation(); buildAggExpr(agg, collator_id, context); buildGroupBy(agg, collator_id, context); @@ -80,7 +81,8 @@ void AggregationBinder::toMPPSubPlan(size_t & executor_index, const DAGPropertie false, std::move(agg_exprs), std::move(gby_exprs), - false); + false, + fine_grained_shuffle_stream_count); partial_agg->children.push_back(children[0]); std::vector partition_keys; size_t agg_func_num = partial_agg->agg_exprs.size(); @@ -206,7 +208,7 @@ void AggregationBinder::buildAggFunc(tipb::Expr * agg_func, const ASTFunction * agg_func->set_aggfuncmode(tipb::AggFunctionMode::Partial1Mode); } -ExecutorBinderPtr compileAggregation(ExecutorBinderPtr input, size_t & executor_index, ASTPtr agg_funcs, ASTPtr group_by_exprs) +ExecutorBinderPtr compileAggregation(ExecutorBinderPtr input, size_t & executor_index, ASTPtr agg_funcs, ASTPtr group_by_exprs, uint64_t fine_grained_shuffle_stream_count) { std::vector agg_exprs; std::vector gby_exprs; @@ -276,7 +278,8 @@ ExecutorBinderPtr compileAggregation(ExecutorBinderPtr input, size_t & executor_ need_append_project, std::move(agg_exprs), std::move(gby_exprs), - true); + true, + fine_grained_shuffle_stream_count); aggregation->children.push_back(input); return aggregation; } diff --git a/dbms/src/Debug/MockExecutor/AggregationBinder.h b/dbms/src/Debug/MockExecutor/AggregationBinder.h index 84821594988..005549e6f0b 100644 --- a/dbms/src/Debug/MockExecutor/AggregationBinder.h +++ b/dbms/src/Debug/MockExecutor/AggregationBinder.h @@ -25,13 +25,14 @@ class ExchangeReceiverBinder; class AggregationBinder : public ExecutorBinder { public: - AggregationBinder(size_t & index_, const DAGSchema & output_schema_, bool has_uniq_raw_res_, bool need_append_project_, ASTs && agg_exprs_, ASTs && gby_exprs_, bool is_final_mode_) + AggregationBinder(size_t & index_, const DAGSchema & output_schema_, bool has_uniq_raw_res_, bool need_append_project_, ASTs && agg_exprs_, ASTs && gby_exprs_, bool is_final_mode_, uint64_t fine_grained_shuffle_stream_count_) : ExecutorBinder(index_, "aggregation_" + std::to_string(index_), output_schema_) , has_uniq_raw_res(has_uniq_raw_res_) , need_append_project(need_append_project_) , agg_exprs(std::move(agg_exprs_)) , gby_exprs(std::move(gby_exprs_)) , is_final_mode(is_final_mode_) + , fine_grained_shuffle_stream_count(fine_grained_shuffle_stream_count_) {} bool toTiPBExecutor(tipb::Executor * tipb_executor, int32_t collator_id, const MPPInfo & mpp_info, const Context & context) override; @@ -53,6 +54,7 @@ class AggregationBinder : public ExecutorBinder std::vector gby_exprs; bool is_final_mode; DAGSchema output_schema_for_partial_agg; + uint64_t fine_grained_shuffle_stream_count; private: void buildGroupBy(tipb::Aggregation * agg, int32_t collator_id, const Context & context) const; @@ -60,6 +62,6 @@ class AggregationBinder : public ExecutorBinder void buildAggFunc(tipb::Expr * agg_func, const ASTFunction * func, int32_t collator_id) const; }; -ExecutorBinderPtr compileAggregation(ExecutorBinderPtr input, size_t & executor_index, ASTPtr agg_funcs, ASTPtr group_by_exprs); +ExecutorBinderPtr compileAggregation(ExecutorBinderPtr input, size_t & executor_index, ASTPtr agg_funcs, ASTPtr group_by_exprs, uint64_t fine_grained_shuffle_stream_count = 0); } // namespace DB::mock diff --git a/dbms/src/Debug/MockExecutor/JoinBinder.cpp b/dbms/src/Debug/MockExecutor/JoinBinder.cpp index df0f11c2133..e9bc36bc5d0 100644 --- a/dbms/src/Debug/MockExecutor/JoinBinder.cpp +++ b/dbms/src/Debug/MockExecutor/JoinBinder.cpp @@ -140,6 +140,7 @@ bool JoinBinder::toTiPBExecutor(tipb::Executor * tipb_executor, int32_t collator { tipb_executor->set_tp(tipb::ExecType::TypeJoin); tipb_executor->set_executor_id(name); + tipb_executor->set_fine_grained_shuffle_stream_count(fine_grained_shuffle_stream_count); tipb::Join * join = tipb_executor->mutable_join(); @@ -288,14 +289,15 @@ ExecutorBinderPtr compileJoin(size_t & executor_index, const ASTs & left_conds, const ASTs & right_conds, const ASTs & other_conds, - const ASTs & other_eq_conds_from_in) + const ASTs & other_eq_conds_from_in, + uint64_t fine_grained_shuffle_stream_count) { DAGSchema output_schema; buildLeftSideJoinSchema(output_schema, left->output_schema, tp); buildRightSideJoinSchema(output_schema, right->output_schema, tp); - auto join = std::make_shared(executor_index, output_schema, tp, join_cols, left_conds, right_conds, other_conds, other_eq_conds_from_in); + auto join = std::make_shared(executor_index, output_schema, tp, join_cols, left_conds, right_conds, other_conds, other_eq_conds_from_in, fine_grained_shuffle_stream_count); join->children.push_back(left); join->children.push_back(right); diff --git a/dbms/src/Debug/MockExecutor/JoinBinder.h b/dbms/src/Debug/MockExecutor/JoinBinder.h index c649420b8a9..cbdcd9d25b9 100644 --- a/dbms/src/Debug/MockExecutor/JoinBinder.h +++ b/dbms/src/Debug/MockExecutor/JoinBinder.h @@ -23,7 +23,7 @@ class ExchangeReceiverBinder; class JoinBinder : public ExecutorBinder { public: - JoinBinder(size_t & index_, const DAGSchema & output_schema_, tipb::JoinType tp_, const ASTs & join_cols_, const ASTs & l_conds, const ASTs & r_conds, const ASTs & o_conds, const ASTs & o_eq_conds) + JoinBinder(size_t & index_, const DAGSchema & output_schema_, tipb::JoinType tp_, const ASTs & join_cols_, const ASTs & l_conds, const ASTs & r_conds, const ASTs & o_conds, const ASTs & o_eq_conds, uint64_t fine_grained_shuffle_stream_count_) : ExecutorBinder(index_, "Join_" + std::to_string(index_), output_schema_) , tp(tp_) , join_cols(join_cols_) @@ -31,6 +31,7 @@ class JoinBinder : public ExecutorBinder , right_conds(r_conds) , other_conds(o_conds) , other_eq_conds_from_in(o_eq_conds) + , fine_grained_shuffle_stream_count(fine_grained_shuffle_stream_count_) { if (!(join_cols.size() + left_conds.size() + right_conds.size() + other_conds.size() + other_eq_conds_from_in.size())) throw Exception("No join condition found."); @@ -57,9 +58,10 @@ class JoinBinder : public ExecutorBinder const ASTs right_conds{}; const ASTs other_conds{}; const ASTs other_eq_conds_from_in{}; + uint64_t fine_grained_shuffle_stream_count; }; // compileJoin constructs a mocked Join executor node, note that all conditional expression params can be default -ExecutorBinderPtr compileJoin(size_t & executor_index, ExecutorBinderPtr left, ExecutorBinderPtr right, tipb::JoinType tp, const ASTs & join_cols, const ASTs & left_conds = {}, const ASTs & right_conds = {}, const ASTs & other_conds = {}, const ASTs & other_eq_conds_from_in = {}); +ExecutorBinderPtr compileJoin(size_t & executor_index, ExecutorBinderPtr left, ExecutorBinderPtr right, tipb::JoinType tp, const ASTs & join_cols, const ASTs & left_conds = {}, const ASTs & right_conds = {}, const ASTs & other_conds = {}, const ASTs & other_eq_conds_from_in = {}, uint64_t fine_grained_shuffle_stream_count = 0); /// Note: this api is only used by legacy test framework for compatibility purpose, which will be depracated soon, diff --git a/dbms/src/Flash/tests/gtest_compute_server.cpp b/dbms/src/Flash/tests/gtest_compute_server.cpp index e90e88f2289..ab53fe00392 100644 --- a/dbms/src/Flash/tests/gtest_compute_server.cpp +++ b/dbms/src/Flash/tests/gtest_compute_server.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include namespace DB @@ -34,6 +35,24 @@ class ComputeServerRunner : public DB::tests::MPPTaskTestUtils {{"s1", TiDB::TP::TypeLong}, {"s2", TiDB::TP::TypeString}, {"s3", TiDB::TP::TypeString}}, {toNullableVec("s1", {1, {}, 10000000, 10000000}), toNullableVec("s2", {"apple", {}, "banana", "test"}), toNullableVec("s3", {"apple", {}, "banana", "test"})}); + /// agg table with 200 rows + std::vector::FieldType>> agg_s1(200); + std::vector> agg_s2(200); + std::vector> agg_s3(200); + for (size_t i = 0; i < 200; ++i) + { + if (i % 30 != 0) + { + agg_s1[i] = i % 20; + agg_s2[i] = {fmt::format("val_{}", i % 10)}; + agg_s3[i] = {fmt::format("val_{}", i)}; + } + } + context.addMockTable( + {"test_db", "test_table_2"}, + {{"s1", TiDB::TP::TypeLong}, {"s2", TiDB::TP::TypeString}, {"s3", TiDB::TP::TypeString}}, + {toNullableVec("s1", agg_s1), toNullableVec("s2", agg_s2), toNullableVec("s3", agg_s3)}); + /// for join context.addMockTable( {"test_db", "l_table"}, @@ -43,9 +62,46 @@ class ComputeServerRunner : public DB::tests::MPPTaskTestUtils {"test_db", "r_table"}, {{"s", TiDB::TP::TypeString}, {"join_c", TiDB::TP::TypeString}}, {toNullableVec("s", {"banana", {}, "banana"}), toNullableVec("join_c", {"apple", {}, "banana"})}); + + /// join left table with 200 rows + std::vector::FieldType>> join_s1(200); + std::vector> join_s2(200); + std::vector> join_s3(200); + for (size_t i = 0; i < 200; ++i) + { + if (i % 20 != 0) + { + agg_s1[i] = i % 5; + agg_s2[i] = {fmt::format("val_{}", i % 6)}; + agg_s3[i] = {fmt::format("val_{}", i)}; + } + } + context.addMockTable( + {"test_db", "l_table_2"}, + {{"s1", TiDB::TP::TypeLong}, {"s2", TiDB::TP::TypeString}, {"s3", TiDB::TP::TypeString}}, + {toNullableVec("s1", agg_s1), toNullableVec("s2", agg_s2), toNullableVec("s3", agg_s3)}); + + /// join right table with 100 rows + std::vector::FieldType>> join_r_s1(100); + std::vector> join_r_s2(100); + std::vector> join_r_s3(100); + for (size_t i = 0; i < 100; ++i) + { + if (i % 20 != 0) + { + join_r_s1[i] = i % 6; + join_r_s2[i] = {fmt::format("val_{}", i % 7)}; + join_r_s3[i] = {fmt::format("val_{}", i)}; + } + } + context.addMockTable( + {"test_db", "r_table_2"}, + {{"s1", TiDB::TP::TypeLong}, {"s2", TiDB::TP::TypeString}, {"s3", TiDB::TP::TypeString}}, + {toNullableVec("s1", join_r_s1), toNullableVec("s2", join_r_s2), toNullableVec("s3", join_r_s3)}); } }; + TEST_F(ComputeServerRunner, runAggTasks) try { @@ -445,5 +501,121 @@ try } } CATCH + +/// For FineGrainedShuffleJoin/Agg test usage, update internal exchange senders/receivers flag +/// Allow select,agg,join,tableScan,exchangeSender,exchangeReceiver,projection executors only +void setFineGrainedShuffleForExchange(tipb::Executor & root) +{ + tipb::Executor * current = &root; + while (current) + { + switch (current->tp()) + { + case tipb::ExecType::TypeSelection: + current = const_cast(¤t->selection().child()); + break; + case tipb::ExecType::TypeAggregation: + current = const_cast(¤t->aggregation().child()); + break; + case tipb::ExecType::TypeProjection: + current = const_cast(¤t->projection().child()); + break; + case tipb::ExecType::TypeJoin: + { + /// update build side path + JoinInterpreterHelper::TiFlashJoin tiflash_join{current->join()}; + current = const_cast(¤t->join().children()[tiflash_join.build_side_index]); + break; + } + case tipb::ExecType::TypeExchangeSender: + if (current->exchange_sender().tp() == tipb::Hash) + current->set_fine_grained_shuffle_stream_count(8); + current = const_cast(¤t->exchange_sender().child()); + break; + case tipb::ExecType::TypeExchangeReceiver: + current->set_fine_grained_shuffle_stream_count(8); + current = nullptr; + break; + case tipb::ExecType::TypeTableScan: + current = nullptr; + break; + default: + throw TiFlashException("Should not reach here", Errors::Coprocessor::Internal); + } + } +} + +TEST_F(ComputeServerRunner, runFineGrainedShuffleJoinTest) +try +{ + startServers(3); + constexpr size_t join_type_num = 7; + constexpr tipb::JoinType join_types[join_type_num] = { + tipb::JoinType::TypeInnerJoin, + tipb::JoinType::TypeLeftOuterJoin, + tipb::JoinType::TypeRightOuterJoin, + tipb::JoinType::TypeSemiJoin, + tipb::JoinType::TypeAntiSemiJoin, + tipb::JoinType::TypeLeftOuterSemiJoin, + tipb::JoinType::TypeAntiLeftOuterSemiJoin, + }; + // fine-grained shuffle is enabled. + constexpr uint64_t enable = 8; + constexpr uint64_t disable = 0; + + for (auto join_type : join_types) + { + std::cout << "JoinType: " << static_cast(join_type) << std::endl; + auto properties = DB::tests::getDAGPropertiesForTest(serverNum()); + auto request = context + .scan("test_db", "l_table_2") + .join(context.scan("test_db", "r_table_2"), join_type, {col("s1"), col("s2")}, disable) + .project({col("l_table_2.s1"), col("l_table_2.s2"), col("l_table_2.s3")}); + const auto expected_cols = buildAndExecuteMPPTasks(request); + + auto request2 = context + .scan("test_db", "l_table_2") + .join(context.scan("test_db", "r_table_2"), join_type, {col("s1"), col("s2")}, enable) + .project({col("l_table_2.s1"), col("l_table_2.s2"), col("l_table_2.s3")}); + auto tasks = request2.buildMPPTasks(context, properties); + for (auto & task : tasks) + { + setFineGrainedShuffleForExchange(const_cast(task.dag_request->root_executor())); + } + const auto actual_cols = executeMPPTasks(tasks, properties, MockComputeServerManager::instance().getServerConfigMap()); + ASSERT_COLUMNS_EQ_UR(expected_cols, actual_cols); + } +} +CATCH + +TEST_F(ComputeServerRunner, runFineGrainedShuffleAggTest) +try +{ + startServers(3); + // fine-grained shuffle is enabled. + constexpr uint64_t enable = 8; + constexpr uint64_t disable = 0; + { + auto properties = DB::tests::getDAGPropertiesForTest(serverNum()); + auto request = context + .scan("test_db", "test_table_2") + .aggregation({Max(col("s3"))}, {col("s1"), col("s2")}, disable); + const auto expected_cols = buildAndExecuteMPPTasks(request); + + auto request2 = context + .scan("test_db", "test_table_2") + .aggregation({Max(col("s3"))}, {col("s1"), col("s2")}, enable); + auto tasks = request2.buildMPPTasks(context, properties); + for (auto & task : tasks) + { + setFineGrainedShuffleForExchange(const_cast(task.dag_request->root_executor())); + } + + const auto actual_cols = executeMPPTasks(tasks, properties, MockComputeServerManager::instance().getServerConfigMap()); + ASSERT_COLUMNS_EQ_UR(expected_cols, actual_cols); + } +} +CATCH + } // namespace tests } // namespace DB diff --git a/dbms/src/Flash/tests/gtest_interpreter.cpp b/dbms/src/Flash/tests/gtest_interpreter.cpp index 736166929bc..0afa65390ac 100644 --- a/dbms/src/Flash/tests/gtest_interpreter.cpp +++ b/dbms/src/Flash/tests/gtest_interpreter.cpp @@ -391,12 +391,110 @@ Union: } CATCH +TEST_F(InterpreterExecuteTest, FineGrainedShuffleJoin) +try +{ + // fine-grained shuffle is enabled. + const uint64_t enable = 8; + const uint64_t disable = 0; + { + // Join Source. + DAGRequestBuilder receiver1 = context.receive("sender_l"); + DAGRequestBuilder receiver2 = context.receive("sender_r", enable); + + auto request = receiver1.join( + receiver2, + tipb::JoinType::TypeLeftOuterJoin, + {col("join_c")}, + enable) + .build(context); + + String expected = R"( +CreatingSets + Union: + HashJoinBuild x 10: , join_kind = Left + Expression: + Expression: + MockExchangeReceiver + Union: + Expression x 10: + Expression: + HashJoinProbe: + Expression: + MockExchangeReceiver)"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } + { + // Join Source. + DAGRequestBuilder receiver1 = context.receive("sender_l"); + DAGRequestBuilder receiver2 = context.receive("sender_r", disable); + + auto request = receiver1.join( + receiver2, + tipb::JoinType::TypeLeftOuterJoin, + {col("join_c")}, + disable) + .build(context); + + String expected = R"( +CreatingSets + Union: + HashJoinBuild x 10: , join_kind = Left + Expression: + Expression: + MockExchangeReceiver + Union: + Expression x 10: + Expression: + HashJoinProbe: + Expression: + MockExchangeReceiver)"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } +} +CATCH + +TEST_F(InterpreterExecuteTest, FineGrainedShuffleAgg) +try +{ + // fine-grained shuffle is enabled. + const uint64_t enable = 8; + const uint64_t disable = 0; + { + DAGRequestBuilder receiver1 = context.receive("sender_1", enable); + auto request = receiver1 + .aggregation({Max(col("s1"))}, {col("s2")}, enable) + .build(context); + String expected = R"( +Union: + Expression x 10: + Aggregating: + MockExchangeReceiver)"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } + + { + DAGRequestBuilder receiver1 = context.receive("sender_1", disable); + auto request = receiver1 + .aggregation({Max(col("s1"))}, {col("s2")}, disable) + .build(context); + String expected = R"( +Union: + Expression x 10: + SharedQuery: + ParallelAggregating, max_threads: 10, final: true + MockExchangeReceiver x 10)"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } +} +CATCH + TEST_F(InterpreterExecuteTest, Join) try { // TODO: Find a way to write the request easier. { - // Join Source. + // join + ExchangeReceiver DAGRequestBuilder table1 = context.scan("test_db", "r_table"); DAGRequestBuilder table2 = context.scan("test_db", "l_table"); DAGRequestBuilder table3 = context.scan("test_db", "r_table"); diff --git a/dbms/src/Flash/tests/gtest_planner_interpreter.cpp b/dbms/src/Flash/tests/gtest_planner_interpreter.cpp index e9f99891642..eb6de71ca4e 100644 --- a/dbms/src/Flash/tests/gtest_planner_interpreter.cpp +++ b/dbms/src/Flash/tests/gtest_planner_interpreter.cpp @@ -723,6 +723,108 @@ Union: } CATCH +TEST_F(PlannerInterpreterExecuteTest, FineGrainedShuffleJoin) +try +{ + // fine-grained shuffle is enabled. + const uint64_t enable = 8; + const uint64_t disable = 0; + { + // Join Source. + DAGRequestBuilder receiver1 = context.receive("sender_l"); + DAGRequestBuilder receiver2 = context.receive("sender_r", enable); + + auto request = receiver1.join( + receiver2, + tipb::JoinType::TypeLeftOuterJoin, + {col("join_c")}, + enable) + .build(context); + + String expected = R"( +CreatingSets + Union: + HashJoinBuild x 10: , join_kind = Left + Expression: + Expression: + MockExchangeReceiver + Union: + Expression x 10: + Expression: + HashJoinProbe: + Expression: + MockExchangeReceiver)"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } + { + // Join Source. + DAGRequestBuilder receiver1 = context.receive("sender_l"); + DAGRequestBuilder receiver2 = context.receive("sender_r", disable); + + auto request = receiver1.join( + receiver2, + tipb::JoinType::TypeLeftOuterJoin, + {col("join_c")}, + disable) + .build(context); + + String expected = R"( +CreatingSets + Union: + HashJoinBuild x 10: , join_kind = Left + Expression: + Expression: + MockExchangeReceiver + Union: + Expression x 10: + Expression: + HashJoinProbe: + Expression: + MockExchangeReceiver)"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } +} +CATCH + +TEST_F(PlannerInterpreterExecuteTest, FineGrainedShuffleAgg) +try +{ + // fine-grained shuffle is enabled. + const uint64_t enable = 8; + const uint64_t disable = 0; + { + DAGRequestBuilder receiver1 = context.receive("sender_1", enable); + auto request = receiver1 + .aggregation({Max(col("s1"))}, {col("s2")}, enable) + .build(context); + String expected = R"( +Union: + Expression x 10: + Expression: + Aggregating: + Expression: + MockExchangeReceiver)"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } + + { + DAGRequestBuilder receiver1 = context.receive("sender_1", disable); + auto request = receiver1 + .aggregation({Max(col("s1"))}, {col("s2")}, disable) + .build(context); + String expected = R"( +Union: + Expression x 10: + Expression: + SharedQuery: + ParallelAggregating, max_threads: 10, final: true + Expression x 10: + MockExchangeReceiver)"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } +} +CATCH + TEST_F(PlannerInterpreterExecuteTest, Join) try { diff --git a/dbms/src/Interpreters/Join.cpp b/dbms/src/Interpreters/Join.cpp index 15d94e0c195..102fa05f9f6 100644 --- a/dbms/src/Interpreters/Join.cpp +++ b/dbms/src/Interpreters/Join.cpp @@ -1214,7 +1214,10 @@ void NO_INLINE joinBlockImplTypeCase( /// 2. In ExchangeReceiver, build_stream_id = packet_stream_id % build_stream_count; /// 3. In HashBuild, build_concurrency decides map's segment size, and build_steam_id decides the segment index auto packet_stream_id = shuffle_hash_data[i] % fine_grained_shuffle_count; - segment_index = packet_stream_id % segment_size; + if likely (fine_grained_shuffle_count == segment_size) + segment_index = packet_stream_id; + else + segment_index = packet_stream_id % segment_size; } else { diff --git a/dbms/src/TestUtils/mockExecutor.cpp b/dbms/src/TestUtils/mockExecutor.cpp index 59000185cdf..66fabb59fbd 100644 --- a/dbms/src/TestUtils/mockExecutor.cpp +++ b/dbms/src/TestUtils/mockExecutor.cpp @@ -279,16 +279,16 @@ DAGRequestBuilder & DAGRequestBuilder::join( MockAstVec left_conds, MockAstVec right_conds, MockAstVec other_conds, - MockAstVec other_eq_conds_from_in) + MockAstVec other_eq_conds_from_in, + uint64_t fine_grained_shuffle_stream_count) { assert(root); assert(right.root); - - root = mock::compileJoin(getExecutorIndex(), root, right.root, tp, join_col_exprs, left_conds, right_conds, other_conds, other_eq_conds_from_in); + root = mock::compileJoin(getExecutorIndex(), root, right.root, tp, join_col_exprs, left_conds, right_conds, other_conds, other_eq_conds_from_in, fine_grained_shuffle_stream_count); return *this; } -DAGRequestBuilder & DAGRequestBuilder::aggregation(ASTPtr agg_func, ASTPtr group_by_expr) +DAGRequestBuilder & DAGRequestBuilder::aggregation(ASTPtr agg_func, ASTPtr group_by_expr, uint64_t fine_grained_shuffle_stream_count) { auto agg_funcs = std::make_shared(); auto group_by_exprs = std::make_shared(); @@ -296,10 +296,10 @@ DAGRequestBuilder & DAGRequestBuilder::aggregation(ASTPtr agg_func, ASTPtr group agg_funcs->children.push_back(agg_func); if (group_by_expr) group_by_exprs->children.push_back(group_by_expr); - return buildAggregation(agg_funcs, group_by_exprs); + return buildAggregation(agg_funcs, group_by_exprs, fine_grained_shuffle_stream_count); } -DAGRequestBuilder & DAGRequestBuilder::aggregation(MockAstVec agg_funcs, MockAstVec group_by_exprs) +DAGRequestBuilder & DAGRequestBuilder::aggregation(MockAstVec agg_funcs, MockAstVec group_by_exprs, uint64_t fine_grained_shuffle_stream_count) { auto agg_func_list = std::make_shared(); auto group_by_expr_list = std::make_shared(); @@ -307,13 +307,13 @@ DAGRequestBuilder & DAGRequestBuilder::aggregation(MockAstVec agg_funcs, MockAst agg_func_list->children.push_back(func); for (const auto & group_by : group_by_exprs) group_by_expr_list->children.push_back(group_by); - return buildAggregation(agg_func_list, group_by_expr_list); + return buildAggregation(agg_func_list, group_by_expr_list, fine_grained_shuffle_stream_count); } -DAGRequestBuilder & DAGRequestBuilder::buildAggregation(ASTPtr agg_funcs, ASTPtr group_by_exprs) +DAGRequestBuilder & DAGRequestBuilder::buildAggregation(ASTPtr agg_funcs, ASTPtr group_by_exprs, uint64_t fine_grained_shuffle_stream_count) { assert(root); - root = compileAggregation(root, getExecutorIndex(), agg_funcs, group_by_exprs); + root = compileAggregation(root, getExecutorIndex(), agg_funcs, group_by_exprs, fine_grained_shuffle_stream_count); return *this; } diff --git a/dbms/src/TestUtils/mockExecutor.h b/dbms/src/TestUtils/mockExecutor.h index 14b314d9c20..8c9b2697ee3 100644 --- a/dbms/src/TestUtils/mockExecutor.h +++ b/dbms/src/TestUtils/mockExecutor.h @@ -122,16 +122,17 @@ class DAGRequestBuilder /// @param right_conds conditional expressions which only reference right table and the join type is right kind /// @param other_conds other conditional expressions /// @param other_eq_conds_from_in equality expressions within in subquery whose join type should be AntiSemiJoin, AntiLeftOuterSemiJoin or LeftOuterSemiJoin - DAGRequestBuilder & join(const DAGRequestBuilder & right, tipb::JoinType tp, MockAstVec join_col_exprs, MockAstVec left_conds, MockAstVec right_conds, MockAstVec other_conds, MockAstVec other_eq_conds_from_in); - DAGRequestBuilder & join(const DAGRequestBuilder & right, tipb::JoinType tp, MockAstVec join_col_exprs) + /// @param fine_grained_shuffle_stream_count decide the generated tipb executor's find_grained_shuffle_stream_count + DAGRequestBuilder & join(const DAGRequestBuilder & right, tipb::JoinType tp, MockAstVec join_col_exprs, MockAstVec left_conds, MockAstVec right_conds, MockAstVec other_conds, MockAstVec other_eq_conds_from_in, uint64_t fine_grained_shuffle_stream_count = 0); + DAGRequestBuilder & join(const DAGRequestBuilder & right, tipb::JoinType tp, MockAstVec join_col_exprs, uint64_t fine_grained_shuffle_stream_count = 0) { - return join(right, tp, join_col_exprs, {}, {}, {}, {}); + return join(right, tp, join_col_exprs, {}, {}, {}, {}, fine_grained_shuffle_stream_count); } // aggregation - DAGRequestBuilder & aggregation(ASTPtr agg_func, ASTPtr group_by_expr); - DAGRequestBuilder & aggregation(MockAstVec agg_funcs, MockAstVec group_by_exprs); + DAGRequestBuilder & aggregation(ASTPtr agg_func, ASTPtr group_by_expr, uint64_t fine_grained_shuffle_stream_count = 0); + DAGRequestBuilder & aggregation(MockAstVec agg_funcs, MockAstVec group_by_exprs, uint64_t fine_grained_shuffle_stream_count = 0); // window DAGRequestBuilder & window(ASTPtr window_func, MockOrderByItem order_by, MockPartitionByItem partition_by, MockWindowFrame frame, uint64_t fine_grained_shuffle_stream_count = 0); @@ -145,7 +146,7 @@ class DAGRequestBuilder private: void initDAGRequest(tipb::DAGRequest & dag_request); - DAGRequestBuilder & buildAggregation(ASTPtr agg_funcs, ASTPtr group_by_exprs); + DAGRequestBuilder & buildAggregation(ASTPtr agg_funcs, ASTPtr group_by_exprs, uint64_t fine_grained_shuffle_stream_count = 0); DAGRequestBuilder & buildExchangeReceiver(const MockColumnInfoVec & columns, uint64_t fine_grained_shuffle_stream_count = 0); mock::ExecutorBinderPtr root;