Skip to content

Commit

Permalink
Fix potential hang in right join (pingcap#6690)
Browse files Browse the repository at this point in the history
  • Loading branch information
windtalker authored Jan 30, 2023
1 parent c5558c4 commit bca56b9
Show file tree
Hide file tree
Showing 10 changed files with 171 additions and 66 deletions.
1 change: 1 addition & 0 deletions dbms/src/Common/FailPoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ std::unordered_map<String, std::shared_ptr<FailPointChannel>> FailPointHelper::f
M(exception_in_creating_set_input_stream) \
M(exception_when_read_from_log) \
M(exception_mpp_hash_build) \
M(exception_mpp_hash_probe) \
M(exception_before_drop_segment) \
M(exception_after_drop_segment) \
M(exception_between_schema_change_in_the_same_diff) \
Expand Down
12 changes: 2 additions & 10 deletions dbms/src/DataStreams/CreatingSetsBlockInputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ namespace DB
namespace FailPoints
{
extern const char exception_in_creating_set_input_stream[];
extern const char exception_mpp_hash_build[];
} // namespace FailPoints
namespace ErrorCodes
{
Expand Down Expand Up @@ -120,7 +119,7 @@ void CreatingSetsBlockInputStream::createAll()
for (auto & elem : subqueries_for_sets)
{
if (elem.second.join)
elem.second.join->setBuildTableState(Join::BuildTableState::WAITING);
elem.second.join->setInitActiveBuildConcurrency();
}
}
Stopwatch watch;
Expand Down Expand Up @@ -238,13 +237,6 @@ void CreatingSetsBlockInputStream::createOne(SubqueryForSet & subquery)
}
}


if (subquery.join)
{
FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_mpp_hash_build);
subquery.join->setBuildTableState(Join::BuildTableState::SUCCEED);
}

if (table_out)
table_out->writeSuffix();

Expand Down Expand Up @@ -294,7 +286,7 @@ void CreatingSetsBlockInputStream::createOne(SubqueryForSet & subquery)
std::unique_lock lock(exception_mutex);
exception_from_workers.push_back(std::current_exception());
if (subquery.join)
subquery.join->setBuildTableState(Join::BuildTableState::FAILED);
subquery.join->meetError();
LOG_ERROR(log, "{} throw exception: {} In {} sec. ", gen_log_msg(), getCurrentExceptionMessage(false, true), watch.elapsedSeconds());
}
}
Expand Down
19 changes: 15 additions & 4 deletions dbms/src/DataStreams/HashJoinBuildBlockInputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,22 @@ namespace DB
{
Block HashJoinBuildBlockInputStream::readImpl()
{
Block block = children.back()->read();
if (!block)
try
{
Block block = children.back()->read();
if (!block)
{
join->finishOneBuild();
return block;
}
join->insertFromBlock(block, concurrency_build_index);
return block;
join->insertFromBlock(block, concurrency_build_index);
return block;
}
catch (...)
{
join->meetError();
throw;
}
}

void HashJoinBuildBlockInputStream::appendInfo(FmtBuffer & buffer) const
Expand Down
26 changes: 17 additions & 9 deletions dbms/src/DataStreams/HashJoinProbeBlockInputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,27 @@ void HashJoinProbeBlockInputStream::cancel(bool kill)

Block HashJoinProbeBlockInputStream::readImpl()
{
// if join finished, return {} directly.
if (squashing_transform.isJoinFinished())
try
{
return Block{};
}
// if join finished, return {} directly.
if (squashing_transform.isJoinFinished())
{
return Block{};
}

while (squashing_transform.needAppendBlock())
while (squashing_transform.needAppendBlock())
{
Block result_block = getOutputBlock();
squashing_transform.appendBlock(result_block);
}
auto ret = squashing_transform.getFinalOutputBlock();
return ret;
}
catch (...)
{
Block result_block = getOutputBlock();
squashing_transform.appendBlock(result_block);
join->meetError();
throw;
}
auto ret = squashing_transform.getFinalOutputBlock();
return ret;
}

void HashJoinProbeBlockInputStream::readSuffixImpl()
Expand Down
2 changes: 2 additions & 0 deletions dbms/src/Debug/DBGInvoker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ DBGInvoker::DBGInvoker()
regSchemalessFunc("gc_global_storage_pool", dbgFuncTriggerGlobalPageStorageGC);

regSchemalessFunc("read_index_stress_test", ReadIndexStressTest::dbgFuncStressTest);

regSchemalessFunc("get_active_threads_in_dynamic_thread_pool", dbgFuncActiveThreadsInDynamicThreadPool);
}

void replaceSubstr(std::string & str, const std::string & target, const std::string & replacement)
Expand Down
14 changes: 14 additions & 0 deletions dbms/src/Debug/dbgFuncMisc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <Common/DynamicThreadPool.h>
#include <Common/typeid_cast.h>
#include <Debug/dbgFuncMisc.h>
#include <Interpreters/Context.h>
Expand Down Expand Up @@ -130,4 +131,17 @@ void dbgFuncTriggerGlobalPageStorageGC(Context & context, const ASTs & /*args*/,
global_storage_pool->gc();
}
}

/// Debug-invoke handler: reports the number of active threads in the dynamic thread pool.
///
/// Passes to `output` the current value of the `tiflash_thread_count` /
/// `type_active_threads_of_thdpool` metric, formatted as a decimal Int64.
/// When `DynamicThreadPool::global_instance` is not set, "0" is reported instead.
/// The `Context` and `ASTs` arguments are unused.
void dbgFuncActiveThreadsInDynamicThreadPool(Context &, const ASTs &, DBGInvoker::Printer output)
{
    // No global dynamic thread pool: there is nothing to count.
    if (!DynamicThreadPool::global_instance)
    {
        output("0");
        return;
    }

    auto active_threads = GET_METRIC(tiflash_thread_count, type_active_threads_of_thdpool).Value();
    output(std::to_string(static_cast<Int64>(active_threads)));
}
} // namespace DB
4 changes: 3 additions & 1 deletion dbms/src/Debug/dbgFuncMisc.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

namespace DB
{

class Context;

// Find the last occurrence of `key` in the log file and extract the first number following the key.
Expand All @@ -33,4 +32,7 @@ void dbgFuncSearchLogForKey(Context & context, const ASTs & args, DBGInvoker::Pr
// ./storage-client.sh "DBGInvoke trigger_global_storage_pool_gc()"
void dbgFuncTriggerGlobalPageStorageGC(Context & context, const ASTs & args, DBGInvoker::Printer output);

// Get active threads in dynamic thread pool, if dynamic thread pool is disabled, return 0
void dbgFuncActiveThreadsInDynamicThreadPool(Context & context, const ASTs & /*args*/, DBGInvoker::Printer /*output*/);

} // namespace DB
67 changes: 55 additions & 12 deletions dbms/src/Interpreters/Join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ namespace FailPoints
{
extern const char random_join_build_failpoint[];
extern const char random_join_prob_failpoint[];
extern const char exception_mpp_hash_build[];
extern const char exception_mpp_hash_probe[];
} // namespace FailPoints

namespace ErrorCodes
Expand Down Expand Up @@ -135,6 +137,7 @@ Join::Join(
, key_names_left(key_names_left_)
, key_names_right(key_names_right_)
, build_concurrency(0)
, active_build_concurrency(0)
, probe_concurrency(0)
, active_probe_concurrency(0)
, collators(collators_)
Expand All @@ -145,7 +148,6 @@ Join::Join(
, other_condition_ptr(other_condition_ptr_)
, original_strictness(strictness)
, max_block_size_for_cross_join(max_block_size_)
, build_table_state(BuildTableState::SUCCEED)
, log(Logger::get(req_id))
, enable_fine_grained_shuffle(enable_fine_grained_shuffle_)
, fine_grained_shuffle_count(fine_grained_shuffle_count_)
Expand All @@ -165,11 +167,14 @@ Join::Join(
LOG_INFO(log, "FineGrainedShuffle flag {}, stream count {}", enable_fine_grained_shuffle, fine_grained_shuffle_count);
}

void Join::setBuildTableState(BuildTableState state_)
void Join::meetError()
{
std::lock_guard lk(build_table_mutex);
build_table_state = state_;
build_table_cv.notify_all();
std::lock_guard lk(build_probe_mutex);
if (meet_error)
return;
meet_error = true;
build_cv.notify_all();
probe_cv.notify_all();
}

bool CanAsColumnString(const IColumn * column)
Expand Down Expand Up @@ -465,6 +470,8 @@ void Join::setBuildConcurrencyAndInitPool(size_t build_concurrency_)
{
if (unlikely(build_concurrency > 0))
throw Exception("Logical error: `setBuildConcurrencyAndInitPool` shouldn't be called more than once", ErrorCodes::LOGICAL_ERROR);
/// Do not set active_build_concurrency here: during the compile stage `joinBlock` is called to generate the header,
/// and if active_build_concurrency were already set, that `joinBlock` call would hang.
build_concurrency = std::max(1, build_concurrency_);

for (size_t i = 0; i < getBuildConcurrencyInternal(); ++i)
Expand Down Expand Up @@ -1992,16 +1999,52 @@ void Join::checkTypesOfKeys(const Block & block_left, const Block & block_right)
}
}

Block Join::joinBlock(ProbeProcessInfo & probe_process_info) const
void Join::finishOneProbe()
{
// ck will use this function to generate header, that's why here is a check.
std::unique_lock lock(build_probe_mutex);
if (active_probe_concurrency == 1)
{
std::unique_lock lk(build_table_mutex);

build_table_cv.wait(lk, [&]() { return build_table_state != BuildTableState::WAITING; });
if (build_table_state == BuildTableState::FAILED) /// throw this exception once failed to build the hash table
throw Exception("Build failed before join probe!");
FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_mpp_hash_probe);
}
--active_probe_concurrency;
if (active_probe_concurrency == 0)
probe_cv.notify_all();
}
/// Marks one build participant as finished.
///
/// Decrements `active_build_concurrency` under `build_probe_mutex` and, once the
/// counter reaches zero, wakes every thread waiting on `build_cv`.
void Join::finishOneBuild()
{
    std::lock_guard guard(build_probe_mutex);
    // The failpoint is evaluated only for the last active build, and before the
    // counter is decremented (presumably it throws when the failpoint is enabled,
    // in which case the counter is left untouched).
    if (active_build_concurrency == 1)
    {
        FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_mpp_hash_build);
    }
    if (--active_build_concurrency == 0)
        build_cv.notify_all();
}

/// Blocks until either `meet_error` is set or `active_probe_concurrency` drops to zero.
///
/// @throws Exception if `meet_error` is set when the wait ends.
void Join::waitUntilAllProbeFinished() const
{
    std::unique_lock lock(build_probe_mutex);
    // Explicit wait loop: the condition is re-checked after every wake-up, so a
    // spurious wake-up simply goes back to waiting.
    while (!meet_error && active_probe_concurrency != 0)
        probe_cv.wait(lock);
    if (meet_error)
        throw Exception("Join meet error before all probe finished!");
}

/// Blocks until either `meet_error` is set or `active_build_concurrency` drops to zero.
///
/// @throws Exception if `meet_error` is set when the wait ends.
void Join::waitUntilAllBuildFinished() const
{
    std::unique_lock lock(build_probe_mutex);
    // Explicit wait loop: the condition is re-checked after every wake-up, so a
    // spurious wake-up simply goes back to waiting.
    while (!meet_error && active_build_concurrency != 0)
        build_cv.wait(lock);
    if (meet_error)
        throw Exception("Build failed before join probe!");
}

Block Join::joinBlock(ProbeProcessInfo & probe_process_info) const
{
waitUntilAllBuildFinished();

std::shared_lock lock(rwlock);

Expand Down
49 changes: 20 additions & 29 deletions dbms/src/Interpreters/Join.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,45 +146,35 @@ class Join

const Names & getLeftJoinKeys() const { return key_names_left; }

/// Initializes `active_build_concurrency` to the value returned by
/// `getBuildConcurrencyInternal()`, under `build_probe_mutex`.
void setInitActiveBuildConcurrency()
{
    std::lock_guard guard(build_probe_mutex);
    active_build_concurrency = getBuildConcurrencyInternal();
}
void finishOneBuild();
void waitUntilAllBuildFinished() const;

size_t getProbeConcurrency() const
{
std::unique_lock lock(probe_mutex);
std::unique_lock lock(build_probe_mutex);
return probe_concurrency;
}
void setProbeConcurrency(size_t concurrency)
{
std::unique_lock lock(probe_mutex);
std::unique_lock lock(build_probe_mutex);
probe_concurrency = concurrency;
active_probe_concurrency = probe_concurrency;
}
void finishOneProbe()
{
std::unique_lock lock(probe_mutex);
active_probe_concurrency--;
if (active_probe_concurrency == 0)
probe_cv.notify_all();
}
void waitUntilAllProbeFinished()
{
std::unique_lock lock(probe_mutex);
probe_cv.wait(lock, [&]() {
return active_probe_concurrency == 0;
});
}
void finishOneProbe();
void waitUntilAllProbeFinished() const;

size_t getBuildConcurrency() const
{
std::shared_lock lock(rwlock);
return getBuildConcurrencyInternal();
}

enum BuildTableState
{
WAITING,
FAILED,
SUCCEED
};
void setBuildTableState(BuildTableState state_);
void meetError();

/// Reference to the row in block.
struct RowRef
Expand Down Expand Up @@ -303,13 +293,18 @@ class Join
/// Names of key columns (columns for equi-JOIN) in "right" table (in the order they appear in USING clause).
const Names key_names_right;

mutable std::mutex build_probe_mutex;

mutable std::condition_variable build_cv;
size_t build_concurrency;
size_t active_build_concurrency;

mutable std::mutex probe_mutex;
std::condition_variable probe_cv;
mutable std::condition_variable probe_cv;
size_t probe_concurrency;
size_t active_probe_concurrency;

bool meet_error = false;

private:
/// collators for the join key
const TiDB::TiDBCollators collators;
Expand Down Expand Up @@ -355,10 +350,6 @@ class Join
/// Block with key columns in the same order they appear in the right-side table.
Block sample_block_with_keys;

mutable std::mutex build_table_mutex;
mutable std::condition_variable build_table_cv;
BuildTableState build_table_state;

const LoggerPtr log;

Block totals;
Expand Down
Loading

0 comments on commit bca56b9

Please sign in to comment.