Skip to content

Commit

Permalink
Minor refine of join (#7257)
Browse files Browse the repository at this point in the history
ref #6233
  • Loading branch information
windtalker authored Apr 10, 2023
1 parent cf4802c commit 6549777
Show file tree
Hide file tree
Showing 10 changed files with 260 additions and 260 deletions.
8 changes: 8 additions & 0 deletions dbms/src/Columns/ColumnUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <Columns/ColumnNullable.h>
#include <Columns/ColumnUtils.h>
#include <DataTypes/DataTypeNullable.h>

namespace DB
{
Expand All @@ -30,4 +32,10 @@ bool columnEqual(const ColumnPtr & expected, const ColumnPtr & actual, String &
}
return true;
}
void convertColumnToNullable(ColumnWithTypeAndName & column)
{
column.type = makeNullable(column.type);
if (column.column)
column.column = makeNullable(column.column);
}
} // namespace DB
2 changes: 2 additions & 0 deletions dbms/src/Columns/ColumnUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
#pragma once

#include <Columns/IColumn.h>
#include <Core/ColumnWithTypeAndName.h>

namespace DB
{
bool columnEqual(const ColumnPtr & expected, const ColumnPtr & actual, String & unequal_msg);
void convertColumnToNullable(ColumnWithTypeAndName & column);
} // namespace DB
17 changes: 9 additions & 8 deletions dbms/src/DataStreams/NonJoinedBlockInputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <Columns/ColumnUtils.h>
#include <DataStreams/NonJoinedBlockInputStream.h>
#include <DataStreams/materializeBlock.h>

Expand Down Expand Up @@ -200,12 +201,12 @@ void NonJoinedBlockInputStream::fillColumnsUsingCurrentPartition(
}
if (parent.strictness == ASTTableJoin::Strictness::Any)
{
switch (parent.type)
switch (parent.join_map_method)
{
#define M(TYPE) \
case JoinType::TYPE: \
#define M(METHOD) \
case JoinMapMethod::METHOD: \
fillColumns<ASTTableJoin::Strictness::Any>( \
*partition->maps_any_full.TYPE, \
*partition->maps_any_full.METHOD, \
num_columns_left, \
mutable_columns_left, \
num_columns_right, \
Expand All @@ -221,12 +222,12 @@ void NonJoinedBlockInputStream::fillColumnsUsingCurrentPartition(
}
else if (parent.strictness == ASTTableJoin::Strictness::All)
{
switch (parent.type)
switch (parent.join_map_method)
{
#define M(TYPE) \
case JoinType::TYPE: \
#define M(METHOD) \
case JoinMapMethod::METHOD: \
fillColumns<ASTTableJoin::Strictness::All>( \
*partition->maps_all_full.TYPE, \
*partition->maps_all_full.METHOD, \
num_columns_left, \
mutable_columns_left, \
num_columns_right, \
Expand Down
13 changes: 1 addition & 12 deletions dbms/src/Interpreters/Expand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <Columns/ColumnNullable.h>
#include <DataTypes/DataTypeNullable.h>
#include <Columns/ColumnUtils.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionHelpers.h>
#include <Interpreters/Expand.h>
Expand All @@ -22,16 +21,6 @@
namespace DB
{

namespace
{
void convertColumnToNullable(ColumnWithTypeAndName & column)
{
column.type = makeNullable(column.type);
if (column.column)
column.column = makeNullable(column.column);
}
} // namespace

Expand::Expand(const DB::GroupingSets & gss)
: group_sets_names(gss)
{
Expand Down
124 changes: 7 additions & 117 deletions dbms/src/Interpreters/Join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <Columns/ColumnConst.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnUtils.h>
#include <Common/ColumnsHashing.h>
#include <Common/FailPoint.h>
#include <Common/typeid_cast.h>
Expand Down Expand Up @@ -111,13 +108,6 @@ ColumnRawPtrs extractAndMaterializeKeyColumns(const Block & block, Columns & mat
const std::string Join::match_helper_prefix = "__left-semi-join-match-helper";
const DataTypePtr Join::match_helper_type = makeNullable(std::make_shared<DataTypeInt8>());

void convertColumnToNullable(ColumnWithTypeAndName & column)
{
column.type = makeNullable(column.type);
if (column.column)
column.column = makeNullable(column.column);
}

Join::Join(
const Names & key_names_left_,
const Names & key_names_right_,
Expand Down Expand Up @@ -191,97 +181,11 @@ void Join::meetErrorImpl(const String & error_message_, std::unique_lock<std::mu
probe_cv.notify_all();
}

bool CanAsColumnString(const IColumn * column)
{
return typeid_cast<const ColumnString *>(column)
|| (column->isColumnConst() && typeid_cast<const ColumnString *>(&static_cast<const ColumnConst *>(column)->getDataColumn()));
}

JoinType Join::chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes_) const
{
const size_t keys_size = key_columns.size();

if (keys_size == 0)
return JoinType::CROSS;

bool all_fixed = true;
size_t keys_bytes = 0;
key_sizes_.resize(keys_size);
for (size_t j = 0; j < keys_size; ++j)
{
if (!key_columns[j]->isFixedAndContiguous())
{
all_fixed = false;
break;
}
key_sizes_[j] = key_columns[j]->sizeOfValueIfFixed();
keys_bytes += key_sizes_[j];
}

/// If there is one numeric key that fits in 64 bits
if (keys_size == 1 && key_columns[0]->isNumeric())
{
size_t size_of_field = key_columns[0]->sizeOfValueIfFixed();
if (size_of_field == 1)
return JoinType::key8;
if (size_of_field == 2)
return JoinType::key16;
if (size_of_field == 4)
return JoinType::key32;
if (size_of_field == 8)
return JoinType::key64;
if (size_of_field == 16)
return JoinType::keys128;
throw Exception("Logical error: numeric column has sizeOfField not in 1, 2, 4, 8, 16.", ErrorCodes::LOGICAL_ERROR);
}

/// If the keys fit in N bits, we will use a hash table for N-bit-packed keys
if (all_fixed && keys_bytes <= 16)
return JoinType::keys128;
if (all_fixed && keys_bytes <= 32)
return JoinType::keys256;

/// If there is single string key, use hash table of it's values.
if (keys_size == 1 && CanAsColumnString(key_columns[0]))
{
if (collators.empty() || !collators[0])
return JoinType::key_strbin;
else
{
switch (collators[0]->getCollatorType())
{
case TiDB::ITiDBCollator::CollatorType::UTF8MB4_BIN:
case TiDB::ITiDBCollator::CollatorType::UTF8_BIN:
case TiDB::ITiDBCollator::CollatorType::LATIN1_BIN:
case TiDB::ITiDBCollator::CollatorType::ASCII_BIN:
{
return JoinType::key_strbinpadding;
}
case TiDB::ITiDBCollator::CollatorType::BINARY:
{
return JoinType::key_strbin;
}
default:
{
// for CI COLLATION, use original way
return JoinType::key_string;
}
}
}
}

if (keys_size == 1 && typeid_cast<const ColumnFixedString *>(key_columns[0]))
return JoinType::key_fixed_string;

/// Otherwise, use serialized values as the key.
return JoinType::serialized;
}

size_t Join::getTotalRowCount() const
{
size_t res = 0;

if (type == JoinType::CROSS)
if (join_map_method == JoinMapMethod::CROSS)
{
res = total_input_build_rows;
}
Expand All @@ -304,7 +208,7 @@ size_t Join::getTotalByteCount()
}
else
{
if (type == JoinType::CROSS)
if (join_map_method == JoinMapMethod::CROSS)
{
for (const auto & block : blocks)
res += block.bytes();
Expand Down Expand Up @@ -344,7 +248,7 @@ void Join::setBuildConcurrencyAndInitJoinPartition(size_t build_concurrency_)
partitions.reserve(build_concurrency);
for (size_t i = 0; i < getBuildConcurrency(); ++i)
{
partitions.push_back(std::make_unique<JoinPartition>(type, kind, strictness, max_block_size, log));
partitions.push_back(std::make_unique<JoinPartition>(join_map_method, kind, strictness, max_block_size, log));
}
}

Expand Down Expand Up @@ -412,13 +316,13 @@ void Join::initBuild(const Block & sample_block, size_t build_concurrency_)
if (unlikely(initialized))
throw Exception("Logical error: Join has been initialized", ErrorCodes::LOGICAL_ERROR);
initialized = true;
type = chooseMethod(getKeyColumns(key_names_right, sample_block), key_sizes);
join_map_method = chooseJoinMapMethod(getKeyColumns(key_names_right, sample_block), key_sizes, collators);
setBuildConcurrencyAndInitJoinPartition(build_concurrency_);
build_sample_block = sample_block;
build_spiller = std::make_unique<Spiller>(build_spill_config, false, build_concurrency_, build_sample_block, log);
if (max_bytes_before_external_join > 0)
{
if (type == JoinType::CROSS)
if (join_map_method == JoinMapMethod::CROSS)
{
/// todo support spill for cross join
max_bytes_before_external_join = 0;
Expand Down Expand Up @@ -1353,24 +1257,10 @@ Block Join::joinBlockNullAware(ProbeProcessInfo & probe_process_info) const
{
Block block = probe_process_info.block;

size_t keys_size = key_names_left.size();
ColumnRawPtrs key_columns(keys_size);

/// Rare case, when keys are constant. To avoid code bloat, simply materialize them.
/// Note: this variable can't be removed because it will take smart pointers' lifecycle to the end of this function.
Columns materialized_columns;

/// Memoize key columns to work with.
for (size_t i = 0; i < keys_size; ++i)
{
key_columns[i] = block.getByName(key_names_left[i]).column.get();

if (ColumnPtr converted = key_columns[i]->convertToFullColumnIfConst())
{
materialized_columns.emplace_back(converted);
key_columns[i] = materialized_columns.back().get();
}
}
ColumnRawPtrs key_columns = extractAndMaterializeKeyColumns(block, materialized_columns, key_names_left);

/// Note that `extractAllKeyNullMap` must be done before `extractNestedColumnsAndNullMap`
/// because `extractNestedColumnsAndNullMap` will change the nullable column to its nested column.
Expand Down
8 changes: 2 additions & 6 deletions dbms/src/Interpreters/Join.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include <DataStreams/IBlockInputStream.h>
#include <Flash/Coprocessor/JoinInterpreterHelper.h>
#include <Interpreters/ExpressionActions.h>
#include <Interpreters/JoinHashTable.h>
#include <Interpreters/JoinHashMap.h>
#include <Interpreters/JoinPartition.h>
#include <Interpreters/SettingsCommon.h>

Expand Down Expand Up @@ -291,9 +291,7 @@ class Join
bool has_build_data_in_memory = false;

private:
JoinType type = JoinType::EMPTY;

JoinType chooseMethod(const ColumnRawPtrs & key_columns, Sizes & key_sizes) const;
JoinMapMethod join_map_method = JoinMapMethod::EMPTY;

Sizes key_sizes;

Expand Down Expand Up @@ -395,6 +393,4 @@ struct RestoreInfo
, probe_stream(probe_stream_){};
};

void convertColumnToNullable(ColumnWithTypeAndName & column);

} // namespace DB
Loading

0 comments on commit 6549777

Please sign in to comment.