diff --git a/dbms/src/DataStreams/LimitByBlockInputStream.cpp b/dbms/src/DataStreams/LimitByBlockInputStream.cpp deleted file mode 100644 index 83e93041c34..00000000000 --- a/dbms/src/DataStreams/LimitByBlockInputStream.cpp +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2022 PingCAP, Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - - -namespace DB -{ - -LimitByBlockInputStream::LimitByBlockInputStream(const BlockInputStreamPtr & input, size_t group_size_, const Names & columns) - : columns_names(columns) - , group_size(group_size_) -{ - children.push_back(input); -} - -Block LimitByBlockInputStream::readImpl() -{ - /// Execute until end of stream or until - /// a block with some new records will be gotten. - while (true) - { - Block block = children[0]->read(); - if (!block) - return Block(); - - const ColumnRawPtrs column_ptrs(getKeyColumns(block)); - const size_t rows = block.rows(); - IColumn::Filter filter(rows); - size_t inserted_count = 0; - - for (size_t i = 0; i < rows; ++i) - { - UInt128 key; - SipHash hash; - - for (auto & column : column_ptrs) - column->updateHashWithValue(i, hash); - - hash.get128(key); - - if (keys_counts[key]++ < group_size) - { - inserted_count++; - filter[i] = 1; - } - else - filter[i] = 0; - } - - /// Just go to the next block if there isn't any new records in the current one. - if (!inserted_count) - continue; - - size_t all_columns = block.columns(); - for (size_t i = 0; i < all_columns; ++i) - block.safeGetByPosition(i).column = block.safeGetByPosition(i).column->filter(filter, inserted_count); - - return block; - } -} - -ColumnRawPtrs LimitByBlockInputStream::getKeyColumns(Block & block) const -{ - ColumnRawPtrs column_ptrs; - column_ptrs.reserve(columns_names.size()); - - for (const auto & name : columns_names) - { - auto & column = block.getByName(name).column; - - /// Ignore all constant columns. - if (!column->isColumnConst()) - column_ptrs.emplace_back(column.get()); - } - - return column_ptrs; -} - -} diff --git a/dbms/src/DataStreams/LimitByBlockInputStream.h b/dbms/src/DataStreams/LimitByBlockInputStream.h deleted file mode 100644 index 4a91f0ca9cc..00000000000 --- a/dbms/src/DataStreams/LimitByBlockInputStream.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2022 PingCAP, Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include - -namespace DB -{ - -/** Implements LIMIT BY clause witch can be used to obtain a "top N by subgroup". - * - * For example, if you have table T like this (Num: 1 1 3 3 3 4 4 5 7 7 7 7), - * the query SELECT Num FROM T LIMIT 2 BY Num - * will give you the following result: (Num: 1 1 3 3 4 4 5 7 7). - */ -class LimitByBlockInputStream : public IProfilingBlockInputStream -{ -public: - LimitByBlockInputStream(const BlockInputStreamPtr & input, size_t group_size_, const Names & columns); - - String getName() const override { return "LimitBy"; } - - Block getHeader() const override { return children.at(0)->getHeader(); } - -protected: - Block readImpl() override; - -private: - ColumnRawPtrs getKeyColumns(Block & block) const; - -private: - using MapHashed = HashMap; - - const Names columns_names; - const size_t group_size; - MapHashed keys_counts; -}; - -} diff --git a/dbms/src/DataStreams/TiRemoteBlockInputStream.h b/dbms/src/DataStreams/TiRemoteBlockInputStream.h index a0b90464dff..124f08d65c4 100644 --- a/dbms/src/DataStreams/TiRemoteBlockInputStream.h +++ b/dbms/src/DataStreams/TiRemoteBlockInputStream.h @@ -19,8 +19,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -50,15 +50,10 @@ class TiRemoteBlockInputStream : public IProfilingBlockInputStream String name; - /// this atomic variable is kind of a lock for the struct of execution_summaries: - /// if execution_summaries_inited[index] = true, the map execution_summaries[index] - /// itself will not be modified, so ExecutionSummaryCollector can read it safely, otherwise, - /// ExecutionSummaryCollector will just skip execution_summaries[index] - std::vector> execution_summaries_inited; - std::vector> execution_summaries; - const LoggerPtr log; + RemoteExecutionSummary remote_execution_summary; + uint64_t total_rows; // For fine grained shuffle, sender will partition data into muiltiple streams by hashing. @@ -68,64 +63,6 @@ class TiRemoteBlockInputStream : public IProfilingBlockInputStream std::unique_ptr decoder_ptr; - void initRemoteExecutionSummaries(tipb::SelectResponse & resp, size_t index) - { - for (const auto & execution_summary : resp.execution_summaries()) - { - if (likely(execution_summary.has_executor_id())) - { - auto & remote_execution_summary = execution_summaries[index][execution_summary.executor_id()]; - remote_execution_summary.time_processed_ns = execution_summary.time_processed_ns(); - remote_execution_summary.num_produced_rows = execution_summary.num_produced_rows(); - remote_execution_summary.num_iterations = execution_summary.num_iterations(); - remote_execution_summary.concurrency = execution_summary.concurrency(); - DM::ScanContext scan_context; - scan_context.deserialize(execution_summary.tiflash_scan_context()); - remote_execution_summary.scan_context->merge(scan_context); - } - } - execution_summaries_inited[index].store(true); - } - - void addRemoteExecutionSummaries(tipb::SelectResponse & resp, size_t index) - { - if (unlikely(resp.execution_summaries_size() == 0)) - return; - - if (!execution_summaries_inited[index].load()) - { - initRemoteExecutionSummaries(resp, index); - return; - } - if constexpr (is_streaming_reader) - throw Exception( - fmt::format( - "There are more than one execution summary packet of index {} in streaming reader, " - "this should not happen", - index)); - auto & execution_summaries_map = execution_summaries[index]; - for (const auto & execution_summary : resp.execution_summaries()) - { - if (likely(execution_summary.has_executor_id())) - { - const auto & executor_id = execution_summary.executor_id(); - if (unlikely(execution_summaries_map.find(executor_id) == execution_summaries_map.end())) - { - LOG_WARNING(log, "execution {} not found in execution_summaries, this should not happen", executor_id); - continue; - } - auto & remote_execution_summary = execution_summaries_map[executor_id]; - remote_execution_summary.time_processed_ns = std::max(remote_execution_summary.time_processed_ns, execution_summary.time_processed_ns()); - remote_execution_summary.num_produced_rows += execution_summary.num_produced_rows(); - remote_execution_summary.num_iterations += execution_summary.num_iterations(); - remote_execution_summary.concurrency += execution_summary.concurrency(); - DM::ScanContext scan_context; - scan_context.deserialize(execution_summary.tiflash_scan_context()); - remote_execution_summary.scan_context->merge(scan_context); - } - } - } - bool fetchRemoteResult() { while (true) @@ -147,14 +84,13 @@ class TiRemoteBlockInputStream : public IProfilingBlockInputStream throw Exception(result.resp->error().DebugString()); } - size_t index = 0; - if constexpr (is_streaming_reader) - index = result.call_index; - /// only the last response contains execution summaries if (result.resp != nullptr) - addRemoteExecutionSummaries(*result.resp, index); + remote_execution_summary.add(*result.resp); + size_t index = 0; + if constexpr (is_streaming_reader) + index = result.call_index; const auto & decode_detail = result.decode_detail; auto & connection_profile_info = connection_profile_infos[index]; connection_profile_info.packets += decode_detail.packets; @@ -179,16 +115,10 @@ class TiRemoteBlockInputStream : public IProfilingBlockInputStream : remote_reader(remote_reader_) , source_num(remote_reader->getSourceNum()) , name(fmt::format("TiRemote({})", RemoteReader::name)) - , execution_summaries_inited(source_num) , log(Logger::get(name, req_id, executor_id)) , total_rows(0) , stream_id(stream_id_) { - for (size_t i = 0; i < source_num; ++i) - { - execution_summaries_inited[i].store(false); - } - execution_summaries.resize(source_num); connection_profile_infos.resize(source_num); sample_block = Block(getColumnWithTypeAndName(toNamesAndTypes(remote_reader->getOutputSchema()))); static constexpr size_t squash_rows_limit = 8192; @@ -228,9 +158,9 @@ class TiRemoteBlockInputStream : public IProfilingBlockInputStream return block; } - const std::unordered_map * getRemoteExecutionSummaries(size_t index) + const RemoteExecutionSummary & getRemoteExecutionSummary() { - return execution_summaries_inited[index].load() ? &execution_summaries[index] : nullptr; + return remote_execution_summary; } size_t getTotalRows() const { return total_rows; } diff --git a/dbms/src/DataTypes/DataTypeNullable.cpp b/dbms/src/DataTypes/DataTypeNullable.cpp index fbcda5065c3..c99ee2edf83 100644 --- a/dbms/src/DataTypes/DataTypeNullable.cpp +++ b/dbms/src/DataTypes/DataTypeNullable.cpp @@ -67,12 +67,12 @@ void DataTypeNullable::serializeBinaryBulkWithMultipleStreams( bool position_independent_encoding, SubstreamPath & path) const { - const ColumnNullable & col = static_cast(column); + const auto & col = static_cast(column); col.checkConsistency(); /// First serialize null map. path.push_back(Substream::NullMap); - if (auto stream = getter(path)) + if (auto * stream = getter(path)) DataTypeUInt8().serializeBinaryBulk(col.getNullMapColumn(), *stream, offset, limit); /// Then serialize contents of arrays. @@ -89,10 +89,10 @@ void DataTypeNullable::deserializeBinaryBulkWithMultipleStreams( bool position_independent_encoding, SubstreamPath & path) const { - ColumnNullable & col = static_cast(column); + auto & col = static_cast(column); path.push_back(Substream::NullMap); - if (auto stream = getter(path)) + if (auto * stream = getter(path)) DataTypeUInt8().deserializeBinaryBulk(col.getNullMapColumn(), *stream, limit, 0); path.back() = Substream::NullableElements; @@ -100,50 +100,9 @@ void DataTypeNullable::deserializeBinaryBulkWithMultipleStreams( } -void DataTypeNullable::serializeWidenBinaryBulkWithMultipleStreams( - const IColumn & column, - const OutputStreamGetter & getter, - size_t offset, - size_t limit, - bool position_independent_encoding, - SubstreamPath & path) const -{ - const ColumnNullable & col = static_cast(column); - col.checkConsistency(); - - /// First serialize null map. - path.push_back(Substream::NullMap); - if (auto stream = getter(path)) - DataTypeUInt8().serializeBinaryBulk(col.getNullMapColumn(), *stream, offset, limit); - - /// Then serialize contents of arrays. - path.back() = Substream::NullableElements; - nested_data_type->serializeWidenBinaryBulkWithMultipleStreams(col.getNestedColumn(), getter, offset, limit, position_independent_encoding, path); -} - - -void DataTypeNullable::deserializeWidenBinaryBulkWithMultipleStreams( - IColumn & column, - const InputStreamGetter & getter, - size_t limit, - double avg_value_size_hint, - bool position_independent_encoding, - SubstreamPath & path) const -{ - ColumnNullable & col = static_cast(column); - - path.push_back(Substream::NullMap); - if (auto stream = getter(path)) - DataTypeUInt8().deserializeBinaryBulk(col.getNullMapColumn(), *stream, limit, 0); - - path.back() = Substream::NullableElements; - nested_data_type->deserializeWidenBinaryBulkWithMultipleStreams(col.getNestedColumn(), getter, limit, avg_value_size_hint, position_independent_encoding, path); -} - - void DataTypeNullable::serializeBinary(const IColumn & column, size_t row_num, WriteBuffer & ostr) const { - const ColumnNullable & col = static_cast(column); + const auto & col = static_cast(column); bool is_null = col.isNullAt(row_num); writeBinary(is_null, ostr); @@ -159,7 +118,7 @@ static void safeDeserialize( CheckForNull && check_for_null, DeserializeNested && deserialize_nested) { - ColumnNullable & col = static_cast(column); + auto & col = static_cast(column); if (check_for_null()) { @@ -186,14 +145,14 @@ void DataTypeNullable::deserializeBinary(IColumn & column, ReadBuffer & istr) co { safeDeserialize( column, - [&istr] { bool is_null = 0; readBinary(is_null, istr); return is_null; }, + [&istr] { bool is_null = false; readBinary(is_null, istr); return is_null; }, [this, &istr](IColumn & nested) { nested_data_type->deserializeBinary(nested, istr); }); } void DataTypeNullable::serializeTextEscaped(const IColumn & column, size_t row_num, WriteBuffer & ostr) const { - const ColumnNullable & col = static_cast(column); + const auto & col = static_cast(column); if (col.isNullAt(row_num)) writeCString("\\N", ostr); @@ -261,7 +220,7 @@ void DataTypeNullable::deserializeTextEscaped(IColumn & column, ReadBuffer & ist void DataTypeNullable::serializeTextQuoted(const IColumn & column, size_t row_num, WriteBuffer & ostr) const { - const ColumnNullable & col = static_cast(column); + const auto & col = static_cast(column); if (col.isNullAt(row_num)) writeCString("NULL", ostr); @@ -280,7 +239,7 @@ void DataTypeNullable::deserializeTextQuoted(IColumn & column, ReadBuffer & istr void DataTypeNullable::serializeTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr) const { - const ColumnNullable & col = static_cast(column); + const auto & col = static_cast(column); if (col.isNullAt(row_num)) writeCString("\\N", ostr); @@ -298,7 +257,7 @@ void DataTypeNullable::deserializeTextCSV(IColumn & column, ReadBuffer & istr, c void DataTypeNullable::serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr) const { - const ColumnNullable & col = static_cast(column); + const auto & col = static_cast(column); if (col.isNullAt(row_num)) writeCString("NULL", ostr); @@ -308,7 +267,7 @@ void DataTypeNullable::serializeText(const IColumn & column, size_t row_num, Wri void DataTypeNullable::serializeTextJSON(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettingsJSON & settings) const { - const ColumnNullable & col = static_cast(column); + const auto & col = static_cast(column); if (col.isNullAt(row_num)) writeCString("null", ostr); @@ -326,7 +285,7 @@ void DataTypeNullable::deserializeTextJSON(IColumn & column, ReadBuffer & istr) void DataTypeNullable::serializeTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr) const { - const ColumnNullable & col = static_cast(column); + const auto & col = static_cast(column); if (col.isNullAt(row_num)) writeCString("\\N", ostr); diff --git a/dbms/src/DataTypes/DataTypeNullable.h b/dbms/src/DataTypes/DataTypeNullable.h index 711aabfb905..1c7d5c9d2ab 100644 --- a/dbms/src/DataTypes/DataTypeNullable.h +++ b/dbms/src/DataTypes/DataTypeNullable.h @@ -50,22 +50,6 @@ class DataTypeNullable final : public IDataType bool position_independent_encoding, SubstreamPath & path) const override; - void serializeWidenBinaryBulkWithMultipleStreams( - const IColumn & column, - const OutputStreamGetter & getter, - size_t offset, - size_t limit, - bool position_independent_encoding, - SubstreamPath & path) const override; - - void deserializeWidenBinaryBulkWithMultipleStreams( - IColumn & column, - const InputStreamGetter & getter, - size_t limit, - double avg_value_size_hint, - bool position_independent_encoding, - SubstreamPath & path) const override; - void serializeBinary(const Field & field, WriteBuffer & ostr) const override { nested_data_type->serializeBinary(field, ostr); } void deserializeBinary(Field & field, ReadBuffer & istr) const override { nested_data_type->deserializeBinary(field, istr); } void serializeBinary(const IColumn & column, size_t row_num, WriteBuffer & ostr) const override; diff --git a/dbms/src/DataTypes/DataTypeNumberBase.cpp b/dbms/src/DataTypes/DataTypeNumberBase.cpp index 768e6aa7d1f..9b90a145852 100644 --- a/dbms/src/DataTypes/DataTypeNumberBase.cpp +++ b/dbms/src/DataTypes/DataTypeNumberBase.cpp @@ -42,7 +42,7 @@ void DataTypeNumberBase::serializeTextEscaped(const IColumn & column, size_t template static void deserializeText(IColumn & column, ReadBuffer & istr) { - T x; + T x{}; if constexpr (std::is_integral_v && std::is_arithmetic_v) readIntTextUnsafe(x, istr); @@ -198,7 +198,7 @@ template void DataTypeNumberBase::serializeBinary(const Field & field, WriteBuffer & ostr) const { /// ColumnVector::value_type is a narrower type. For example, UInt8, when the Field type is UInt64 - typename ColumnVector::value_type x = get::Type>(field); + auto x = get::Type>(field); writeBinary(x, ostr); } @@ -214,15 +214,6 @@ template void DataTypeNumberBase::serializeBinary(const IColumn & column, size_t row_num, WriteBuffer & ostr) const { writeBinary(static_cast &>(column).getData()[row_num], ostr); - // if (likely(widened)) - // { - // using WidestType = typename NearestFieldType::Type; - // writeBinary(static_cast(static_cast &>(column).getData()[row_num]), ostr); - // } - // else - // { - // writeBinary(static_cast &>(column).getData()[row_num], ostr); - // } } template @@ -231,18 +222,6 @@ void DataTypeNumberBase::deserializeBinary(IColumn & column, ReadBuffer & ist typename ColumnVector::value_type x; readBinary(x, istr); static_cast &>(column).getData().push_back(x); - // if (likely(widened)) - // { - // using WidestType = typename NearestFieldType::Type; - // typename ColumnVector::value_type y; - // readBinary(y, istr); - // x = static_cast(y); - // } - // else - // { - // readBinary(x, istr); - // } - // static_cast &>(column).getData().push_back(x); } template @@ -268,51 +247,6 @@ void DataTypeNumberBase::deserializeBinaryBulk(IColumn & column, ReadBuffer & x.resize(initial_size + size / sizeof(typename ColumnVector::value_type)); } -template -void DataTypeNumberBase::serializeWidenBinaryBulk(const IColumn & column, WriteBuffer & ostr, size_t offset, size_t limit) const -{ - if (!widened) - return serializeBinaryBulk(column, ostr, offset, limit); - - const typename ColumnVector::Container & x = typeid_cast &>(column).getData(); - - size_t size = x.size(); - - if (limit == 0 || offset + limit > size) - limit = size - offset; - - using WidestType = typename NearestFieldType::Type; - typename ColumnVector::Container y(limit); - for (size_t i = 0; i < limit; i++) - { - y[i] = static_cast(x[offset + i]); - } - - ostr.write(reinterpret_cast(&y[0]), sizeof(typename ColumnVector::value_type) * limit); -} - -template -void DataTypeNumberBase::deserializeWidenBinaryBulk(IColumn & column, ReadBuffer & istr, size_t limit, double avg_value_size_hint) const -{ - if (!widened) - return deserializeBinaryBulk(column, istr, limit, avg_value_size_hint); - - typename ColumnVector::Container & x = typeid_cast &>(column).getData(); - size_t initial_size = x.size(); - x.resize(initial_size + limit); - - using WidestType = typename NearestFieldType::Type; - typename ColumnVector::Container y(limit); - size_t size = istr.readBig(reinterpret_cast(&y[0]), sizeof(typename ColumnVector::value_type) * limit); - size_t elem_size = size / sizeof(typename ColumnVector::value_type); - for (size_t i = 0; i < elem_size; i++) - { - x[initial_size + i] = static_cast(y[i]); - } - - x.resize(initial_size + elem_size); -} - template MutableColumnPtr DataTypeNumberBase::createColumn() const { diff --git a/dbms/src/DataTypes/DataTypeNumberBase.h b/dbms/src/DataTypes/DataTypeNumberBase.h index 6d8f3162981..4b75b954d0b 100644 --- a/dbms/src/DataTypes/DataTypeNumberBase.h +++ b/dbms/src/DataTypes/DataTypeNumberBase.h @@ -51,8 +51,6 @@ class DataTypeNumberBase : public IDataType void deserializeBinary(IColumn & column, ReadBuffer & istr) const override; void serializeBinaryBulk(const IColumn & column, WriteBuffer & ostr, size_t offset, size_t limit) const override; void deserializeBinaryBulk(IColumn & column, ReadBuffer & istr, size_t limit, double avg_value_size_hint) const override; - void serializeWidenBinaryBulk(const IColumn & column, WriteBuffer & ostr, size_t offset, size_t limit) const override; - void deserializeWidenBinaryBulk(IColumn & column, ReadBuffer & istr, size_t limit, double avg_value_size_hint) const override; MutableColumnPtr createColumn() const override; @@ -67,9 +65,6 @@ class DataTypeNumberBase : public IDataType bool haveMaximumSizeOfValue() const override { return true; } size_t getSizeOfValueInMemory() const override { return sizeof(T); } bool isCategorial() const override { return isValueRepresentedByInteger(); } - -protected: - bool widened = false; }; } // namespace DB diff --git a/dbms/src/DataTypes/DataTypesNumber.h b/dbms/src/DataTypes/DataTypesNumber.h index 7ac1294473f..7021abea3ad 100644 --- a/dbms/src/DataTypes/DataTypesNumber.h +++ b/dbms/src/DataTypes/DataTypesNumber.h @@ -35,14 +35,6 @@ class DataTypeNumber final : public DataTypeNumberBase bool isInteger() const override { return std::is_integral_v; } bool isFloatingPoint() const override { return std::is_floating_point_v; } bool canBeInsideNullable() const override { return true; } - -public: - DataTypePtr widen() const override - { - auto t = std::make_shared>(); - t->widened = true; - return t; - } }; using DataTypeUInt8 = DataTypeNumber; diff --git a/dbms/src/DataTypes/IDataType.h b/dbms/src/DataTypes/IDataType.h index 71fda0615e4..58f30600a72 100644 --- a/dbms/src/DataTypes/IDataType.h +++ b/dbms/src/DataTypes/IDataType.h @@ -179,71 +179,6 @@ class IDataType : private boost::noncopyable virtual void serializeBinaryBulk(const IColumn & column, WriteBuffer & ostr, size_t offset, size_t limit) const; virtual void deserializeBinaryBulk(IColumn & column, ReadBuffer & istr, size_t limit, double avg_value_size_hint) const; - /** Widen version for `serializeBinaryBulkWithMultipleStreams`. - */ - virtual void serializeWidenBinaryBulkWithMultipleStreams( - const IColumn & column, - const OutputStreamGetter & getter, - size_t offset, - size_t limit, - bool /*position_independent_encoding*/, - SubstreamPath & path) const - { - if (WriteBuffer * stream = getter(path)) - serializeWidenBinaryBulk(column, *stream, offset, limit); - } - - void serializeWidenBinaryBulkWithMultipleStreams( - const IColumn & column, - const OutputStreamGetter & getter, - size_t offset, - size_t limit, - bool position_independent_encoding, - SubstreamPath && path) const - { - serializeWidenBinaryBulkWithMultipleStreams(column, getter, offset, limit, position_independent_encoding, path); - } - - - /** Widen version for `deserializeBinaryBulkWithMultipleStreams`. - */ - virtual void deserializeWidenBinaryBulkWithMultipleStreams( - IColumn & column, - const InputStreamGetter & getter, - size_t limit, - double avg_value_size_hint, - bool /*position_independent_encoding*/, - SubstreamPath & path) const - { - if (ReadBuffer * stream = getter(path)) - deserializeWidenBinaryBulk(column, *stream, limit, avg_value_size_hint); - } - - void deserializeWidenBinaryBulkWithMultipleStreams( - IColumn & column, - const InputStreamGetter & getter, - size_t limit, - double avg_value_size_hint, - bool position_independent_encoding, - SubstreamPath && path) const - { - deserializeWidenBinaryBulkWithMultipleStreams(column, getter, limit, avg_value_size_hint, position_independent_encoding, path); - } - - /** Widen version for `serializeBinaryBulk`. - */ - virtual void serializeWidenBinaryBulk(const IColumn & column, WriteBuffer & ostr, size_t offset, size_t limit) const - { - serializeBinaryBulk(column, ostr, offset, limit); - } - - /** Widen version for `deserializeBinaryBulk`. - */ - virtual void deserializeWidenBinaryBulk(IColumn & column, ReadBuffer & istr, size_t limit, double avg_value_size_hint) const - { - deserializeBinaryBulk(column, istr, limit, avg_value_size_hint); - } - /** Serialization/deserialization of individual values. * * These are helper methods for implementation of various formats to input/output for user (like CSV, JSON, etc.). @@ -479,15 +414,6 @@ class IDataType : private boost::noncopyable */ virtual bool canBeInsideNullable() const { return false; }; - /** Some specific data types are required to be widened for some specific storage for whatever reason, - * i.e. to avoid data rewriting upon type change, - * TMT will intentionally store narrow type (int8/16/32) to its widest possible type (int64) of the same family, - * meanwhile behaves as its original narrow type. - * Given that most data type objects on the fly are const (DataTypePtr), this function returns a new copy of the widened type. - */ - virtual DataTypePtr widen() const { return nullptr; } - - /// Updates avg_value_size_hint for newly read column. Uses to optimize deserialization. Zero expected for first column. static void updateAvgValueSizeHint(const IColumn & column, double & avg_value_size_hint); diff --git a/dbms/src/Debug/MockComputeServerManager.cpp b/dbms/src/Debug/MockComputeServerManager.cpp index 839cd794b73..64b5c6fc7d4 100644 --- a/dbms/src/Debug/MockComputeServerManager.cpp +++ b/dbms/src/Debug/MockComputeServerManager.cpp @@ -13,6 +13,8 @@ // limitations under the License. #include #include +#include +#include #include #include diff --git a/dbms/src/Debug/MockComputeServerManager.h b/dbms/src/Debug/MockComputeServerManager.h index 6642388659f..dd622e00b70 100644 --- a/dbms/src/Debug/MockComputeServerManager.h +++ b/dbms/src/Debug/MockComputeServerManager.h @@ -15,10 +15,8 @@ #pragma once #include -#include +#include #include -#include -#include namespace DB::tests { diff --git a/dbms/src/Debug/MockExecutor/AggregationBinder.cpp b/dbms/src/Debug/MockExecutor/AggregationBinder.cpp index a39f196a389..e95346af901 100644 --- a/dbms/src/Debug/MockExecutor/AggregationBinder.cpp +++ b/dbms/src/Debug/MockExecutor/AggregationBinder.cpp @@ -12,11 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include +#include #include #include -#include +#include +#include #include namespace DB::mock @@ -25,6 +28,7 @@ bool AggregationBinder::toTiPBExecutor(tipb::Executor * tipb_executor, int32_t c { tipb_executor->set_tp(tipb::ExecType::TypeAggregation); tipb_executor->set_executor_id(name); + tipb_executor->set_fine_grained_shuffle_stream_count(fine_grained_shuffle_stream_count); auto * agg = tipb_executor->mutable_aggregation(); buildAggExpr(agg, collator_id, context); buildGroupBy(agg, collator_id, context); @@ -77,7 +81,8 @@ void AggregationBinder::toMPPSubPlan(size_t & executor_index, const DAGPropertie false, std::move(agg_exprs), std::move(gby_exprs), - false); + false, + fine_grained_shuffle_stream_count); partial_agg->children.push_back(children[0]); std::vector partition_keys; size_t agg_func_num = partial_agg->agg_exprs.size(); @@ -203,7 +208,7 @@ void AggregationBinder::buildAggFunc(tipb::Expr * agg_func, const ASTFunction * agg_func->set_aggfuncmode(tipb::AggFunctionMode::Partial1Mode); } -ExecutorBinderPtr compileAggregation(ExecutorBinderPtr input, size_t & executor_index, ASTPtr agg_funcs, ASTPtr group_by_exprs) +ExecutorBinderPtr compileAggregation(ExecutorBinderPtr input, size_t & executor_index, ASTPtr agg_funcs, ASTPtr group_by_exprs, uint64_t fine_grained_shuffle_stream_count) { std::vector agg_exprs; std::vector gby_exprs; @@ -273,7 +278,8 @@ ExecutorBinderPtr compileAggregation(ExecutorBinderPtr input, size_t & executor_ need_append_project, std::move(agg_exprs), std::move(gby_exprs), - true); + true, + fine_grained_shuffle_stream_count); aggregation->children.push_back(input); return aggregation; } diff --git a/dbms/src/Debug/MockExecutor/AggregationBinder.h b/dbms/src/Debug/MockExecutor/AggregationBinder.h index 4ece3ff7838..005549e6f0b 100644 --- a/dbms/src/Debug/MockExecutor/AggregationBinder.h +++ b/dbms/src/Debug/MockExecutor/AggregationBinder.h @@ -14,22 +14,25 @@ #pragma once -#include -#include #include +#include namespace DB::mock { +class ExchangeSenderBinder; +class ExchangeReceiverBinder; + class AggregationBinder : public ExecutorBinder { public: - AggregationBinder(size_t & index_, const DAGSchema & output_schema_, bool has_uniq_raw_res_, bool need_append_project_, ASTs && agg_exprs_, ASTs && gby_exprs_, bool is_final_mode_) + AggregationBinder(size_t & index_, const DAGSchema & output_schema_, bool has_uniq_raw_res_, bool need_append_project_, ASTs && agg_exprs_, ASTs && gby_exprs_, bool is_final_mode_, uint64_t fine_grained_shuffle_stream_count_) : ExecutorBinder(index_, "aggregation_" + std::to_string(index_), output_schema_) , has_uniq_raw_res(has_uniq_raw_res_) , need_append_project(need_append_project_) , agg_exprs(std::move(agg_exprs_)) , gby_exprs(std::move(gby_exprs_)) , is_final_mode(is_final_mode_) + , fine_grained_shuffle_stream_count(fine_grained_shuffle_stream_count_) {} bool toTiPBExecutor(tipb::Executor * tipb_executor, int32_t collator_id, const MPPInfo & mpp_info, const Context & context) override; @@ -51,6 +54,7 @@ class AggregationBinder : public ExecutorBinder std::vector gby_exprs; bool is_final_mode; DAGSchema output_schema_for_partial_agg; + uint64_t fine_grained_shuffle_stream_count; private: void buildGroupBy(tipb::Aggregation * agg, int32_t collator_id, const Context & context) const; @@ -58,6 +62,6 @@ class AggregationBinder : public ExecutorBinder void buildAggFunc(tipb::Expr * agg_func, const ASTFunction * func, int32_t collator_id) const; }; -ExecutorBinderPtr compileAggregation(ExecutorBinderPtr input, size_t & executor_index, ASTPtr agg_funcs, ASTPtr group_by_exprs); +ExecutorBinderPtr compileAggregation(ExecutorBinderPtr input, size_t & executor_index, ASTPtr agg_funcs, ASTPtr group_by_exprs, uint64_t fine_grained_shuffle_stream_count = 0); } // namespace DB::mock diff --git a/dbms/src/Debug/MockExecutor/AstToPB.cpp b/dbms/src/Debug/MockExecutor/AstToPB.cpp index 306d2c24813..fa58e2e3fc8 100644 --- a/dbms/src/Debug/MockExecutor/AstToPB.cpp +++ b/dbms/src/Debug/MockExecutor/AstToPB.cpp @@ -12,7 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace DB { diff --git a/dbms/src/Debug/MockExecutor/AstToPB.h b/dbms/src/Debug/MockExecutor/AstToPB.h index c1560c90355..2f25618c361 100644 --- a/dbms/src/Debug/MockExecutor/AstToPB.h +++ b/dbms/src/Debug/MockExecutor/AstToPB.h @@ -14,24 +14,8 @@ #pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include namespace DB { @@ -41,6 +25,11 @@ extern const int BAD_ARGUMENTS; extern const int LOGICAL_ERROR; extern const int NO_SUCH_COLUMN_IN_TABLE; } // namespace ErrorCodes + +class ASTFunction; +class ASTIdentifier; +class Context; + struct MPPCtx { Timestamp start_ts; diff --git a/dbms/src/Debug/MockExecutor/ExchangeReceiverBinder.cpp b/dbms/src/Debug/MockExecutor/ExchangeReceiverBinder.cpp index 706624856c0..21d4d649ffe 100644 --- a/dbms/src/Debug/MockExecutor/ExchangeReceiverBinder.cpp +++ b/dbms/src/Debug/MockExecutor/ExchangeReceiverBinder.cpp @@ -12,8 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include +#include #include #include +#include +#include namespace DB::mock { diff --git a/dbms/src/Debug/MockExecutor/ExchangeReceiverBinder.h b/dbms/src/Debug/MockExecutor/ExchangeReceiverBinder.h index 2885dfd895d..c2327c87861 100644 --- a/dbms/src/Debug/MockExecutor/ExchangeReceiverBinder.h +++ b/dbms/src/Debug/MockExecutor/ExchangeReceiverBinder.h @@ -14,6 +14,7 @@ #pragma once +#include #include namespace DB::mock diff --git a/dbms/src/Debug/MockExecutor/ExchangeSenderBinder.cpp b/dbms/src/Debug/MockExecutor/ExchangeSenderBinder.cpp index aaba39868e1..45abb7de9fa 100644 --- a/dbms/src/Debug/MockExecutor/ExchangeSenderBinder.cpp +++ b/dbms/src/Debug/MockExecutor/ExchangeSenderBinder.cpp @@ -12,8 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include +#include #include #include +#include +#include namespace DB::mock { diff --git a/dbms/src/Debug/MockExecutor/ExchangeSenderBinder.h b/dbms/src/Debug/MockExecutor/ExchangeSenderBinder.h index 0b8b33821cf..ed6710ac22e 100644 --- a/dbms/src/Debug/MockExecutor/ExchangeSenderBinder.h +++ b/dbms/src/Debug/MockExecutor/ExchangeSenderBinder.h @@ -14,6 +14,7 @@ #pragma once +#include #include namespace DB::mock diff --git a/dbms/src/Debug/MockExecutor/ExecutorBinder.h b/dbms/src/Debug/MockExecutor/ExecutorBinder.h index de8e3c9928c..d1a03ff96d3 100644 --- a/dbms/src/Debug/MockExecutor/ExecutorBinder.h +++ b/dbms/src/Debug/MockExecutor/ExecutorBinder.h @@ -14,10 +14,13 @@ #pragma once +#include #include -#include #include -#include +#include +#include +#include +#include namespace DB::mock @@ -25,7 +28,6 @@ namespace DB::mock class ExchangeSenderBinder; class ExchangeReceiverBinder; - // Convert CH AST to tipb::Executor // Used in integration test framework and Unit test framework. class ExecutorBinder @@ -45,7 +47,7 @@ class ExecutorBinder index_++; } - std::vector> getChildren() + std::vector> getChildren() const { return children; } diff --git a/dbms/src/Debug/MockExecutor/JoinBinder.cpp b/dbms/src/Debug/MockExecutor/JoinBinder.cpp index 92109b73f1b..e9bc36bc5d0 100644 --- a/dbms/src/Debug/MockExecutor/JoinBinder.cpp +++ b/dbms/src/Debug/MockExecutor/JoinBinder.cpp @@ -12,10 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include #include #include #include #include +#include +#include #include namespace DB::mock @@ -136,6 +140,7 @@ bool JoinBinder::toTiPBExecutor(tipb::Executor * tipb_executor, int32_t collator { tipb_executor->set_tp(tipb::ExecType::TypeJoin); tipb_executor->set_executor_id(name); + tipb_executor->set_fine_grained_shuffle_stream_count(fine_grained_shuffle_stream_count); tipb::Join * join = tipb_executor->mutable_join(); @@ -284,14 +289,15 @@ ExecutorBinderPtr compileJoin(size_t & executor_index, const ASTs & left_conds, const ASTs & right_conds, const ASTs & other_conds, - const ASTs & other_eq_conds_from_in) + const ASTs & other_eq_conds_from_in, + uint64_t fine_grained_shuffle_stream_count) { DAGSchema output_schema; buildLeftSideJoinSchema(output_schema, left->output_schema, tp); buildRightSideJoinSchema(output_schema, right->output_schema, tp); - auto join = std::make_shared(executor_index, output_schema, tp, join_cols, left_conds, right_conds, other_conds, other_eq_conds_from_in); + auto join = std::make_shared(executor_index, output_schema, tp, join_cols, left_conds, right_conds, other_conds, other_eq_conds_from_in, fine_grained_shuffle_stream_count); join->children.push_back(left); join->children.push_back(right); diff --git a/dbms/src/Debug/MockExecutor/JoinBinder.h b/dbms/src/Debug/MockExecutor/JoinBinder.h index 5ab1fb83f4b..cbdcd9d25b9 100644 --- a/dbms/src/Debug/MockExecutor/JoinBinder.h +++ b/dbms/src/Debug/MockExecutor/JoinBinder.h @@ -14,16 +14,16 @@ #pragma once -#include -#include #include namespace DB::mock { +class ExchangeSenderBinder; +class ExchangeReceiverBinder; class JoinBinder : public ExecutorBinder { public: - JoinBinder(size_t & index_, const DAGSchema & output_schema_, tipb::JoinType tp_, const ASTs & join_cols_, const ASTs & l_conds, const ASTs & r_conds, const ASTs & o_conds, const ASTs & o_eq_conds) + JoinBinder(size_t & index_, const DAGSchema & output_schema_, tipb::JoinType tp_, const ASTs & join_cols_, const ASTs & l_conds, const ASTs & r_conds, const ASTs & o_conds, const ASTs & o_eq_conds, uint64_t fine_grained_shuffle_stream_count_) : ExecutorBinder(index_, "Join_" + std::to_string(index_), output_schema_) , tp(tp_) , join_cols(join_cols_) @@ -31,6 +31,7 @@ class JoinBinder : public ExecutorBinder , right_conds(r_conds) , other_conds(o_conds) , other_eq_conds_from_in(o_eq_conds) + , fine_grained_shuffle_stream_count(fine_grained_shuffle_stream_count_) { if (!(join_cols.size() + left_conds.size() + right_conds.size() + other_conds.size() + other_eq_conds_from_in.size())) throw Exception("No join condition found."); @@ -57,9 +58,10 @@ class JoinBinder : public ExecutorBinder const ASTs right_conds{}; const ASTs other_conds{}; const ASTs other_eq_conds_from_in{}; + uint64_t fine_grained_shuffle_stream_count; }; // compileJoin constructs a mocked Join executor node, note that all conditional expression params can be default -ExecutorBinderPtr compileJoin(size_t & executor_index, ExecutorBinderPtr left, ExecutorBinderPtr right, tipb::JoinType tp, const ASTs & join_cols, const ASTs & left_conds = {}, const ASTs & right_conds = {}, const ASTs & other_conds = {}, const ASTs & other_eq_conds_from_in = {}); +ExecutorBinderPtr compileJoin(size_t & executor_index, ExecutorBinderPtr left, ExecutorBinderPtr right, tipb::JoinType tp, const ASTs & join_cols, const ASTs & left_conds = {}, const ASTs & right_conds = {}, const ASTs & other_conds = {}, const ASTs & other_eq_conds_from_in = {}, uint64_t fine_grained_shuffle_stream_count = 0); /// Note: this api is only used by legacy test framework for compatibility purpose, which will be depracated soon, diff --git a/dbms/src/Debug/MockExecutor/LimitBinder.cpp b/dbms/src/Debug/MockExecutor/LimitBinder.cpp index c0a9bf17a82..de90b96f252 100644 --- a/dbms/src/Debug/MockExecutor/LimitBinder.cpp +++ b/dbms/src/Debug/MockExecutor/LimitBinder.cpp @@ -14,6 +14,7 @@ #include #include +#include namespace DB::mock { diff --git a/dbms/src/Debug/MockExecutor/ProjectBinder.cpp b/dbms/src/Debug/MockExecutor/ProjectBinder.cpp index ebe8e5d8bde..50f0646c864 100644 --- a/dbms/src/Debug/MockExecutor/ProjectBinder.cpp +++ b/dbms/src/Debug/MockExecutor/ProjectBinder.cpp @@ -12,9 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include #include #include +#include #include +#include namespace DB::mock { diff --git a/dbms/src/Debug/MockExecutor/SelectionBinder.cpp b/dbms/src/Debug/MockExecutor/SelectionBinder.cpp index cea52b56922..c3171fa5e2c 100644 --- a/dbms/src/Debug/MockExecutor/SelectionBinder.cpp +++ b/dbms/src/Debug/MockExecutor/SelectionBinder.cpp @@ -11,7 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include +#include #include namespace DB::mock diff --git a/dbms/src/Debug/MockExecutor/SelectionBinder.h b/dbms/src/Debug/MockExecutor/SelectionBinder.h index d4270ed5fac..b5e1c2000f3 100644 --- a/dbms/src/Debug/MockExecutor/SelectionBinder.h +++ b/dbms/src/Debug/MockExecutor/SelectionBinder.h @@ -14,7 +14,6 @@ #pragma once -#include #include namespace DB::mock diff --git a/dbms/src/Debug/MockExecutor/SortBinder.cpp b/dbms/src/Debug/MockExecutor/SortBinder.cpp index 80265448824..1af2820c71e 100644 --- a/dbms/src/Debug/MockExecutor/SortBinder.cpp +++ b/dbms/src/Debug/MockExecutor/SortBinder.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include diff --git a/dbms/src/Debug/MockExecutor/TableScanBinder.cpp b/dbms/src/Debug/MockExecutor/TableScanBinder.cpp index e35a14e4269..27f399f6d40 100644 --- a/dbms/src/Debug/MockExecutor/TableScanBinder.cpp +++ b/dbms/src/Debug/MockExecutor/TableScanBinder.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include diff --git a/dbms/src/Debug/MockExecutor/TopNBinder.cpp b/dbms/src/Debug/MockExecutor/TopNBinder.cpp index f8d7dd5f006..aee74b9300b 100644 --- a/dbms/src/Debug/MockExecutor/TopNBinder.cpp +++ b/dbms/src/Debug/MockExecutor/TopNBinder.cpp @@ -12,8 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include +#include #include namespace DB::mock diff --git a/dbms/src/Debug/MockExecutor/WindowBinder.cpp b/dbms/src/Debug/MockExecutor/WindowBinder.cpp index 8da8ae5d8ef..0642300cecb 100644 --- a/dbms/src/Debug/MockExecutor/WindowBinder.cpp +++ b/dbms/src/Debug/MockExecutor/WindowBinder.cpp @@ -12,8 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include +#include #include +#include namespace DB::mock { diff --git a/dbms/src/Debug/MockExecutor/WindowBinder.h b/dbms/src/Debug/MockExecutor/WindowBinder.h index 443506baa33..b9745d3358b 100644 --- a/dbms/src/Debug/MockExecutor/WindowBinder.h +++ b/dbms/src/Debug/MockExecutor/WindowBinder.h @@ -27,7 +27,6 @@ struct MockWindowFrame std::optional end; // TODO: support calcFuncs }; - using ASTPartitionByElement = ASTOrderByElement; class WindowBinder : public ExecutorBinder diff --git a/dbms/src/Debug/MockStorage.cpp b/dbms/src/Debug/MockStorage.cpp index dbcf38c831b..7a19da7085b 100644 --- a/dbms/src/Debug/MockStorage.cpp +++ b/dbms/src/Debug/MockStorage.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #include +#include namespace DB::tests { @@ -22,8 +23,11 @@ void MockStorage::addTableSchema(const String & name, const MockColumnInfoVec & addTableInfo(name, columnInfos); } -void MockStorage::addTableData(const String & name, const ColumnsWithTypeAndName & columns) +void MockStorage::addTableData(const String & name, ColumnsWithTypeAndName & columns) { + for (size_t i = 0; i < columns.size(); ++i) + columns[i].column_id = i; + table_columns[getTableId(name)] = columns; } @@ -123,11 +127,12 @@ CutColumnInfo getCutColumnInfo(size_t rows, Int64 partition_id, Int64 partition_ return {start, cur_rows}; } -ColumnsWithTypeAndName MockStorage::getColumnsForMPPTableScan(Int64 table_id, Int64 partition_id, Int64 partition_num) +ColumnsWithTypeAndName MockStorage::getColumnsForMPPTableScan(const TiDBTableScan & table_scan, Int64 partition_id, Int64 partition_num) { + auto table_id = table_scan.getLogicalTableID(); if (tableExists(table_id)) { - auto columns_with_type_and_name = table_columns[table_id]; + auto columns_with_type_and_name = table_columns[table_scan.getLogicalTableID()]; size_t rows = 0; for (const auto & col : columns_with_type_and_name) { @@ -141,11 +146,23 @@ ColumnsWithTypeAndName MockStorage::getColumnsForMPPTableScan(Int64 table_id, In ColumnsWithTypeAndName res; for (const auto & column_with_type_and_name : columns_with_type_and_name) { - res.push_back( - ColumnWithTypeAndName( - column_with_type_and_name.column->cut(cut_info.first, cut_info.second), - column_with_type_and_name.type, - column_with_type_and_name.name)); + bool contains = false; + for (const auto & column : table_scan.getColumns()) + { + if (column.id == column_with_type_and_name.column_id) + { + contains = true; + break; + } + } + if (contains) + { + res.push_back( + ColumnWithTypeAndName( + column_with_type_and_name.column->cut(cut_info.first, cut_info.second), + column_with_type_and_name.type, + column_with_type_and_name.name)); + } } return res; } diff --git a/dbms/src/Debug/MockStorage.h b/dbms/src/Debug/MockStorage.h index 46e8331602f..ff5ff0627b3 100644 --- a/dbms/src/Debug/MockStorage.h +++ b/dbms/src/Debug/MockStorage.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once #include +#include #include #include @@ -47,7 +48,7 @@ class MockStorage /// for table scan void addTableSchema(const String & name, const MockColumnInfoVec & columnInfos); - void addTableData(const String & name, const ColumnsWithTypeAndName & columns); + void addTableData(const String & name, ColumnsWithTypeAndName & columns); Int64 getTableId(const String & name); @@ -72,7 +73,7 @@ class MockStorage MockColumnInfoVec getExchangeSchema(const String & exchange_name); /// for MPP Tasks, it will split data by partition num, then each MPP service will have a subset of mock data. - ColumnsWithTypeAndName getColumnsForMPPTableScan(Int64 table_id, Int64 partition_id, Int64 partition_num); + ColumnsWithTypeAndName getColumnsForMPPTableScan(const TiDBTableScan & table_scan, Int64 partition_id, Int64 partition_num); TableInfo getTableInfo(const String & name); diff --git a/dbms/src/Debug/dbgFuncCoprocessor.cpp b/dbms/src/Debug/dbgFuncCoprocessor.cpp index 07ee8703b92..112f43b568b 100644 --- a/dbms/src/Debug/dbgFuncCoprocessor.cpp +++ b/dbms/src/Debug/dbgFuncCoprocessor.cpp @@ -15,7 +15,10 @@ #include #include #include +#include +#include #include +#include namespace DB { diff --git a/dbms/src/Debug/dbgFuncCoprocessor.h b/dbms/src/Debug/dbgFuncCoprocessor.h index a296e93d410..9a21842fa50 100644 --- a/dbms/src/Debug/dbgFuncCoprocessor.h +++ b/dbms/src/Debug/dbgFuncCoprocessor.h @@ -15,7 +15,6 @@ #pragma once #include -#include namespace DB { class Context; diff --git a/dbms/src/Debug/dbgFuncCoprocessorUtils.cpp b/dbms/src/Debug/dbgFuncCoprocessorUtils.cpp index b1dd70feba7..e89163c2c1d 100644 --- a/dbms/src/Debug/dbgFuncCoprocessorUtils.cpp +++ b/dbms/src/Debug/dbgFuncCoprocessorUtils.cpp @@ -12,8 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include +#include +#include #include +#include +#include +#include #include +#include +#include +#include +#include +#include +#include namespace DB { diff --git a/dbms/src/Debug/dbgFuncCoprocessorUtils.h b/dbms/src/Debug/dbgFuncCoprocessorUtils.h index 7d9ca5a1075..0f2c3d85533 100644 --- a/dbms/src/Debug/dbgFuncCoprocessorUtils.h +++ b/dbms/src/Debug/dbgFuncCoprocessorUtils.h @@ -13,19 +13,10 @@ // limitations under the License. #pragma once -#include -#include -#include -#include -#include -#include +#include #include -#include -#include #include -#include -#include -#include +#include namespace DB { @@ -33,6 +24,10 @@ namespace ErrorCodes { extern const int BAD_ARGUMENTS; } +class Context; +struct DAGProperties; +class IBlockInputStream; +using BlockInputStreamPtr = std::shared_ptr; std::unique_ptr getCodec(tipb::EncodeType encode_type); DAGSchema getSelectSchema(Context & context); diff --git a/dbms/src/Debug/dbgNaturalDag.h b/dbms/src/Debug/dbgNaturalDag.h index f7c1d850ebe..67c7dca288e 100644 --- a/dbms/src/Debug/dbgNaturalDag.h +++ b/dbms/src/Debug/dbgNaturalDag.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include diff --git a/dbms/src/Debug/dbgQueryCompiler.cpp b/dbms/src/Debug/dbgQueryCompiler.cpp index 2562e6b2efc..f9e58b1a424 100644 --- a/dbms/src/Debug/dbgQueryCompiler.cpp +++ b/dbms/src/Debug/dbgQueryCompiler.cpp @@ -12,7 +12,34 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace DB { diff --git a/dbms/src/Debug/dbgQueryCompiler.h b/dbms/src/Debug/dbgQueryCompiler.h index 748b14d41e8..87397ab0728 100644 --- a/dbms/src/Debug/dbgQueryCompiler.h +++ b/dbms/src/Debug/dbgQueryCompiler.h @@ -16,32 +16,10 @@ #include #include -#include #include -#include -#include #include -#include -#include -#include -#include -#include -#include -#include -#include #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include #include namespace DB @@ -49,6 +27,7 @@ namespace DB using MakeResOutputStream = std::function; using ExecutorBinderPtr = mock::ExecutorBinderPtr; using TableInfo = TiDB::TableInfo; +struct ASTTablesInSelectQueryElement; enum class QueryTaskType { diff --git a/dbms/src/Debug/dbgQueryExecutor.cpp b/dbms/src/Debug/dbgQueryExecutor.cpp index be7ee9b9ca6..359aa833f25 100644 --- a/dbms/src/Debug/dbgQueryExecutor.cpp +++ b/dbms/src/Debug/dbgQueryExecutor.cpp @@ -12,12 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include - +#include +#include +#include +#include namespace DB { +using TiFlashTestEnv = tests::TiFlashTestEnv; + void setTipbRegionInfo(coprocessor::RegionInfo * tipb_region_info, const std::pair & region, TableID table_id) { tipb_region_info->set_region_id(region.first); diff --git a/dbms/src/Debug/dbgQueryExecutor.h b/dbms/src/Debug/dbgQueryExecutor.h index 0b3c639a20c..aa308ada9fc 100644 --- a/dbms/src/Debug/dbgQueryExecutor.h +++ b/dbms/src/Debug/dbgQueryExecutor.h @@ -18,7 +18,6 @@ namespace DB { using MockServerConfig = tests::MockServerConfig; -using TiFlashTestEnv = tests::TiFlashTestEnv; BlockInputStreamPtr executeQuery(Context & context, RegionID region_id, const DAGProperties & properties, QueryTasks & query_tasks, MakeResOutputStream & func_wrap_output_stream); BlockInputStreamPtr executeMPPQuery(Context & context, const DAGProperties & properties, QueryTasks & query_tasks); diff --git a/dbms/src/Flash/Coprocessor/DAGContext.h b/dbms/src/Flash/Coprocessor/DAGContext.h index aaf218ba24e..ce8b93e92d2 100644 --- a/dbms/src/Flash/Coprocessor/DAGContext.h +++ b/dbms/src/Flash/Coprocessor/DAGContext.h @@ -193,6 +193,7 @@ class DAGContext , dummy_query_string(dag_request->DebugString()) , dummy_ast(makeDummyQuery()) , initialize_concurrency(concurrency) + , collect_execution_summaries(dag_request->has_collect_execution_summaries() && dag_request->collect_execution_summaries()) , is_mpp_task(true) , is_root_mpp_task(false) , log(Logger::get(log_identifier)) diff --git a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp index 2902d66b57a..34fd53c3455 100644 --- a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp +++ b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp @@ -176,7 +176,17 @@ void DAGQueryBlockInterpreter::handleMockTableScan(const TiDBTableScan & table_s } else { - auto [names_and_types, mock_table_scan_streams] = mockSourceStream(context, max_streams, log, table_scan.getTableScanExecutorID(), table_scan.getLogicalTableID()); + NamesAndTypes names_and_types; + std::vector> mock_table_scan_streams; + if (context.isMPPTest()) + { + std::tie(names_and_types, mock_table_scan_streams) = mockSourceStreamForMpp(context, max_streams, log, table_scan); + } + else + { + std::tie(names_and_types, mock_table_scan_streams) = mockSourceStream(context, max_streams, log, table_scan.getTableScanExecutorID(), table_scan.getLogicalTableID()); + } + analyzer = std::make_unique(std::move(names_and_types), context); pipeline.streams.insert(pipeline.streams.end(), mock_table_scan_streams.begin(), mock_table_scan_streams.end()); } diff --git a/dbms/src/Flash/Coprocessor/ExecutionSummary.cpp b/dbms/src/Flash/Coprocessor/ExecutionSummary.cpp index a62693f1aec..818d0edfbea 100644 --- a/dbms/src/Flash/Coprocessor/ExecutionSummary.cpp +++ b/dbms/src/Flash/Coprocessor/ExecutionSummary.cpp @@ -16,24 +16,30 @@ namespace DB { +void ExecutionSummary::merge(const ExecutionSummary & other) +{ + time_processed_ns = std::max(time_processed_ns, other.time_processed_ns); + num_produced_rows += other.num_produced_rows; + num_iterations += other.num_iterations; + concurrency += other.concurrency; + scan_context->merge(*other.scan_context); +} + +void ExecutionSummary::merge(const tipb::ExecutorExecutionSummary & other) +{ + time_processed_ns = std::max(time_processed_ns, other.time_processed_ns()); + num_produced_rows += other.num_produced_rows(); + num_iterations += other.num_iterations(); + concurrency += other.concurrency(); + scan_context->merge(other.tiflash_scan_context()); +} -void ExecutionSummary::merge(const ExecutionSummary & other, bool streaming_call) +void ExecutionSummary::init(const tipb::ExecutorExecutionSummary & other) { - if (streaming_call) - { - time_processed_ns = std::max(time_processed_ns, other.time_processed_ns); - num_produced_rows = std::max(num_produced_rows, other.num_produced_rows); - num_iterations = std::max(num_iterations, other.num_iterations); - concurrency = std::max(concurrency, other.concurrency); - scan_context->merge(*other.scan_context); - } - else - { - time_processed_ns = std::max(time_processed_ns, other.time_processed_ns); - num_produced_rows += other.num_produced_rows; - num_iterations += other.num_iterations; - concurrency += other.concurrency; - scan_context->merge(*other.scan_context); - } + time_processed_ns = other.time_processed_ns(); + num_produced_rows = other.num_produced_rows(); + num_iterations = other.num_iterations(); + concurrency = other.concurrency(); + scan_context->deserialize(other.tiflash_scan_context()); } } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/ExecutionSummary.h b/dbms/src/Flash/Coprocessor/ExecutionSummary.h index eafeaeed292..5a8ce579a6e 100644 --- a/dbms/src/Flash/Coprocessor/ExecutionSummary.h +++ b/dbms/src/Flash/Coprocessor/ExecutionSummary.h @@ -16,6 +16,7 @@ #include #include +#include #include @@ -29,11 +30,13 @@ struct ExecutionSummary UInt64 num_iterations = 0; UInt64 concurrency = 0; - std::unique_ptr scan_context = std::make_unique(); + DM::ScanContextPtr scan_context = std::make_shared(); ExecutionSummary() = default; - void merge(const ExecutionSummary & other, bool streaming_call); + void merge(const ExecutionSummary & other); + void merge(const tipb::ExecutorExecutionSummary & other); + void init(const tipb::ExecutorExecutionSummary & other); }; } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/ExecutionSummaryCollector.cpp b/dbms/src/Flash/Coprocessor/ExecutionSummaryCollector.cpp index c21c839760c..86122e400b2 100644 --- a/dbms/src/Flash/Coprocessor/ExecutionSummaryCollector.cpp +++ b/dbms/src/Flash/Coprocessor/ExecutionSummaryCollector.cpp @@ -14,14 +14,31 @@ #include #include +#include #include -#include -#include - -#include +#include namespace DB { +namespace +{ +RemoteExecutionSummary getRemoteExecutionSummariesFromExchange(DAGContext & dag_context) +{ + RemoteExecutionSummary exchange_execution_summary; + for (const auto & map_entry : dag_context.getInBoundIOInputStreamsMap()) + { + for (const auto & stream_ptr : map_entry.second) + { + if (auto * exchange_receiver_stream_ptr = dynamic_cast(stream_ptr.get()); exchange_receiver_stream_ptr) + { + exchange_execution_summary.merge(exchange_receiver_stream_ptr->getRemoteExecutionSummary()); + } + } + } + return exchange_execution_summary; +} +} // namespace + void ExecutionSummaryCollector::fillTiExecutionSummary( tipb::ExecutorExecutionSummary * execution_summary, ExecutionSummary & current, @@ -37,29 +54,6 @@ void ExecutionSummaryCollector::fillTiExecutionSummary( execution_summary->set_executor_id(executor_id); } -template -void mergeRemoteExecuteSummaries( - RemoteBlockInputStream * input_stream, - std::unordered_map> & execution_summaries) -{ - size_t source_num = input_stream->getSourceNum(); - for (size_t s_index = 0; s_index < source_num; ++s_index) - { - auto remote_execution_summaries = input_stream->getRemoteExecutionSummaries(s_index); - if (remote_execution_summaries == nullptr) - continue; - bool is_streaming_call = input_stream->isStreamingCall(); - for (auto & p : *remote_execution_summaries) - { - if (execution_summaries[p.first].size() < source_num) - { - execution_summaries[p.first].resize(source_num); - } - execution_summaries[p.first][s_index].merge(p.second, is_streaming_call); - } - } -} - tipb::SelectResponse ExecutionSummaryCollector::genExecutionSummaryResponse() { tipb::SelectResponse response; @@ -67,89 +61,67 @@ tipb::SelectResponse ExecutionSummaryCollector::genExecutionSummaryResponse() return response; } -void ExecutionSummaryCollector::addExecuteSummaries(tipb::SelectResponse & response) +void ExecutionSummaryCollector::fillLocalExecutionSummary( + tipb::SelectResponse & response, + const String & executor_id, + const BlockInputStreams & streams, + const std::unordered_map & scan_context_map) const { - if (!dag_context.collect_execution_summaries) - return; - /// get executionSummary info from remote input streams - std::unordered_map> merged_remote_execution_summaries; - for (const auto & map_entry : dag_context.getInBoundIOInputStreamsMap()) + ExecutionSummary current; + /// part 1: local execution info + // get execution info from streams + for (const auto & stream_ptr : streams) { - for (const auto & stream_ptr : map_entry.second) + if (auto * p_stream = dynamic_cast(stream_ptr.get())) { - if (auto * exchange_receiver_stream_ptr = dynamic_cast(stream_ptr.get())) - { - mergeRemoteExecuteSummaries(exchange_receiver_stream_ptr, merged_remote_execution_summaries); - } - else if (auto * cop_stream_ptr = dynamic_cast(stream_ptr.get())) - { - mergeRemoteExecuteSummaries(cop_stream_ptr, merged_remote_execution_summaries); - } - else - { - /// local read input stream - } + current.time_processed_ns = std::max(current.time_processed_ns, p_stream->getProfileInfo().execution_time); + current.num_produced_rows += p_stream->getProfileInfo().rows; + current.num_iterations += p_stream->getProfileInfo().blocks; } + ++current.concurrency; } - - auto fill_execution_summary = [&](const String & executor_id, const BlockInputStreams & streams, const std::unordered_map & scan_context_map) { - ExecutionSummary current; - /// part 1: local execution info - // get execution info from streams - for (const auto & stream_ptr : streams) - { - if (auto * p_stream = dynamic_cast(stream_ptr.get())) - { - current.time_processed_ns = std::max(current.time_processed_ns, p_stream->getProfileInfo().execution_time); - current.num_produced_rows += p_stream->getProfileInfo().rows; - current.num_iterations += p_stream->getProfileInfo().blocks; - } - current.concurrency++; - } - // get execution info from scan_context - if (const auto & iter = scan_context_map.find(executor_id); iter != scan_context_map.end()) - { - current.scan_context->merge(*(iter->second)); - } - - /// part 2: remote execution info - if (merged_remote_execution_summaries.find(executor_id) != merged_remote_execution_summaries.end()) - { - for (auto & remote : merged_remote_execution_summaries[executor_id]) - current.merge(remote, false); - } - /// part 3: for join need to add the build time - /// In TiFlash, a hash join's build side is finished before probe side starts, - /// so the join probe side's running time does not include hash table's build time, - /// when construct ExecSummaries, we need add the build cost to probe executor - auto all_join_id_it = dag_context.getExecutorIdToJoinIdMap().find(executor_id); - if (all_join_id_it != dag_context.getExecutorIdToJoinIdMap().end()) + // get execution info from scan_context + if (const auto & iter = scan_context_map.find(executor_id); iter != scan_context_map.end()) + { + current.scan_context->merge(*(iter->second)); + } + /// part 2: for join need to add the build time + /// In TiFlash, a hash join's build side is finished before probe side starts, + /// so the join probe side's running time does not include hash table's build time, + /// when construct ExecSummaries, we need add the build cost to probe executor + auto all_join_id_it = dag_context.getExecutorIdToJoinIdMap().find(executor_id); + if (all_join_id_it != dag_context.getExecutorIdToJoinIdMap().end()) + { + for (const auto & join_executor_id : all_join_id_it->second) { - for (const auto & join_executor_id : all_join_id_it->second) + auto it = dag_context.getJoinExecuteInfoMap().find(join_executor_id); + if (it != dag_context.getJoinExecuteInfoMap().end()) { - auto it = dag_context.getJoinExecuteInfoMap().find(join_executor_id); - if (it != dag_context.getJoinExecuteInfoMap().end()) + UInt64 process_time_for_build = 0; + for (const auto & join_build_stream : it->second.join_build_streams) { - UInt64 process_time_for_build = 0; - for (const auto & join_build_stream : it->second.join_build_streams) - { - if (auto * p_stream = dynamic_cast(join_build_stream.get()); p_stream) - process_time_for_build = std::max(process_time_for_build, p_stream->getProfileInfo().execution_time); - } - current.time_processed_ns += process_time_for_build; + if (auto * p_stream = dynamic_cast(join_build_stream.get()); p_stream) + process_time_for_build = std::max(process_time_for_build, p_stream->getProfileInfo().execution_time); } + current.time_processed_ns += process_time_for_build; } } + } - current.time_processed_ns += dag_context.compile_time_ns; - fillTiExecutionSummary(response.add_execution_summaries(), current, executor_id); - }; + current.time_processed_ns += dag_context.compile_time_ns; + fillTiExecutionSummary(response.add_execution_summaries(), current, executor_id); +} - /// add execution_summary for local executor +void ExecutionSummaryCollector::addExecuteSummaries(tipb::SelectResponse & response) +{ + if (!dag_context.collect_execution_summaries) + return; + + /// fill execution_summary for local executor if (dag_context.return_executor_id) { for (auto & p : dag_context.getProfileStreamsMap()) - fill_execution_summary(p.first, p.second, dag_context.scan_context_map); + fillLocalExecutionSummary(response, p.first, p.second, dag_context.scan_context_map); } else { @@ -159,19 +131,16 @@ void ExecutionSummaryCollector::addExecuteSummaries(tipb::SelectResponse & respo { auto it = profile_streams_map.find(executor_id); assert(it != profile_streams_map.end()); - fill_execution_summary(executor_id, it->second, dag_context.scan_context_map); + fillLocalExecutionSummary(response, executor_id, it->second, dag_context.scan_context_map); } } - for (auto & p : merged_remote_execution_summaries) + // TODO support cop remote read and disaggregated mode. + auto exchange_execution_summary = getRemoteExecutionSummariesFromExchange(dag_context); + // fill execution_summary to reponse for remote executor received by exchange. + for (auto & p : exchange_execution_summary.execution_summaries) { - if (local_executors.find(p.first) == local_executors.end()) - { - ExecutionSummary merged; - for (auto & remote : p.second) - merged.merge(remote, false); - fillTiExecutionSummary(response.add_execution_summaries(), merged, p.first); - } + fillTiExecutionSummary(response.add_execution_summaries(), p.second, p.first); } } } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/ExecutionSummaryCollector.h b/dbms/src/Flash/Coprocessor/ExecutionSummaryCollector.h index dedd488d125..dc5a64e723b 100644 --- a/dbms/src/Flash/Coprocessor/ExecutionSummaryCollector.h +++ b/dbms/src/Flash/Coprocessor/ExecutionSummaryCollector.h @@ -14,23 +14,21 @@ #pragma once -#include +#include #include +#include namespace DB { +class DAGContext; + class ExecutionSummaryCollector { public: explicit ExecutionSummaryCollector( DAGContext & dag_context_) : dag_context(dag_context_) - { - for (auto & p : dag_context.getProfileStreamsMap()) - { - local_executors.insert(p.first); - } - } + {} void addExecuteSummaries(tipb::SelectResponse & response); @@ -42,8 +40,13 @@ class ExecutionSummaryCollector ExecutionSummary & current, const String & executor_id) const; + void fillLocalExecutionSummary( + tipb::SelectResponse & response, + const String & executor_id, + const BlockInputStreams & streams, + const std::unordered_map & scan_context_map) const; + private: DAGContext & dag_context; - std::unordered_set local_executors; }; } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/MockSourceStream.cpp b/dbms/src/Flash/Coprocessor/MockSourceStream.cpp new file mode 100644 index 00000000000..c8e662adc32 --- /dev/null +++ b/dbms/src/Flash/Coprocessor/MockSourceStream.cpp @@ -0,0 +1,24 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +namespace DB +{ +std::pair>> mockSourceStreamForMpp(Context & context, size_t max_streams, DB::LoggerPtr log, const TiDBTableScan & table_scan) +{ + ColumnsWithTypeAndName columns_with_type_and_name = context.mockStorage().getColumnsForMPPTableScan(table_scan, context.mockMPPServerInfo().partition_id, context.mockMPPServerInfo().partition_num); + return cutStreams(context, columns_with_type_and_name, max_streams, log); +} +} // namespace DB diff --git a/dbms/src/Flash/Coprocessor/MockSourceStream.h b/dbms/src/Flash/Coprocessor/MockSourceStream.h index 7cb0ffc95e7..c84d37d2a06 100644 --- a/dbms/src/Flash/Coprocessor/MockSourceStream.h +++ b/dbms/src/Flash/Coprocessor/MockSourceStream.h @@ -18,25 +18,19 @@ #include #include #include +#include #include +#include + namespace DB { - template -std::pair>> mockSourceStream(Context & context, size_t max_streams, DB::LoggerPtr log, String executor_id, Int64 table_id = 0) +std::pair>> cutStreams(Context & context, ColumnsWithTypeAndName & columns_with_type_and_name, size_t max_streams, DB::LoggerPtr log) { - ColumnsWithTypeAndName columns_with_type_and_name; NamesAndTypes names_and_types; size_t rows = 0; std::vector> mock_source_streams; - if constexpr (std::is_same_v) - columns_with_type_and_name = context.mockStorage().getExchangeColumns(executor_id); - else if (context.isMPPTest()) - columns_with_type_and_name = context.mockStorage().getColumnsForMPPTableScan(table_id, context.mockMPPServerInfo().partition_id, context.mockMPPServerInfo().partition_num); - else - columns_with_type_and_name = context.mockStorage().getColumns(table_id); - for (const auto & col : columns_with_type_and_name) { if (rows == 0) @@ -68,4 +62,18 @@ std::pair>> mockSourceStr RUNTIME_ASSERT(start == rows, log, "mock source streams' total size must same as user input"); return {names_and_types, mock_source_streams}; } + +std::pair>> mockSourceStreamForMpp(Context & context, size_t max_streams, DB::LoggerPtr log, const TiDBTableScan & table_scan); + +template +std::pair>> mockSourceStream(Context & context, size_t max_streams, DB::LoggerPtr log, String executor_id, Int64 table_id = 0) +{ + ColumnsWithTypeAndName columns_with_type_and_name; + if constexpr (std::is_same_v) + columns_with_type_and_name = context.mockStorage().getExchangeColumns(executor_id); + else + columns_with_type_and_name = context.mockStorage().getColumns(table_id); + + return cutStreams(context, columns_with_type_and_name, max_streams, log); +} } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/RemoteExecutionSummary.cpp b/dbms/src/Flash/Coprocessor/RemoteExecutionSummary.cpp new file mode 100644 index 00000000000..fc88afcf700 --- /dev/null +++ b/dbms/src/Flash/Coprocessor/RemoteExecutionSummary.cpp @@ -0,0 +1,59 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +namespace DB +{ +void RemoteExecutionSummary::merge(const RemoteExecutionSummary & other) +{ + for (const auto & p : other.execution_summaries) + { + const auto & executor_id = p.first; + auto it = execution_summaries.find(executor_id); + if (unlikely(it == execution_summaries.end())) + { + execution_summaries[executor_id] = p.second; + } + else + { + it->second.merge(p.second); + } + } +} + +void RemoteExecutionSummary::add(tipb::SelectResponse & resp) +{ + if (unlikely(resp.execution_summaries_size() == 0)) + return; + + for (const auto & execution_summary : resp.execution_summaries()) + { + if (likely(execution_summary.has_executor_id())) + { + const auto & executor_id = execution_summary.executor_id(); + auto it = execution_summaries.find(executor_id); + if (unlikely(it == execution_summaries.end())) + { + execution_summaries[executor_id].init(execution_summary); + } + else + { + it->second.merge(execution_summary); + } + } + } +} +} // namespace DB diff --git a/dbms/src/Flash/Coprocessor/RemoteExecutionSummary.h b/dbms/src/Flash/Coprocessor/RemoteExecutionSummary.h new file mode 100644 index 00000000000..dd2a9d0b5bf --- /dev/null +++ b/dbms/src/Flash/Coprocessor/RemoteExecutionSummary.h @@ -0,0 +1,33 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include + +namespace DB +{ +struct RemoteExecutionSummary +{ + void merge(const RemoteExecutionSummary & other); + + void add(tipb::SelectResponse & resp); + + // + std::unordered_map execution_summaries; +}; +} // namespace DB diff --git a/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp b/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp index 2434feeba26..0162b940ce4 100644 --- a/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp +++ b/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp @@ -398,11 +398,10 @@ class TestTiRemoteBlockInputStream : public testing::Test { assert(receiver_stream); /// Check Execution Summary - const auto * summary = receiver_stream->getRemoteExecutionSummaries(0); - ASSERT_TRUE(summary != nullptr); - ASSERT_EQ(summary->size(), 1); - ASSERT_EQ(summary->begin()->first, "Executor_0"); - ASSERT_TRUE(equalSummaries(writer->mockExecutionSummary(), summary->begin()->second)); + const auto & summary = receiver_stream->getRemoteExecutionSummary(); + ASSERT_EQ(summary.execution_summaries.size(), 1); + ASSERT_EQ(summary.execution_summaries.begin()->first, "Executor_0"); + ASSERT_TRUE(equalSummaries(writer->mockExecutionSummary(), summary.execution_summaries.begin()->second)); /// Check Connection Info auto infos = receiver_stream->getConnectionProfileInfos(); diff --git a/dbms/src/Flash/Planner/plans/PhysicalMockTableScan.cpp b/dbms/src/Flash/Planner/plans/PhysicalMockTableScan.cpp index c3019f5ed8b..23833b9af0b 100644 --- a/dbms/src/Flash/Planner/plans/PhysicalMockTableScan.cpp +++ b/dbms/src/Flash/Planner/plans/PhysicalMockTableScan.cpp @@ -50,7 +50,16 @@ std::pair mockSchemaAndStreams( else { /// build from user input blocks. - auto [names_and_types, mock_table_scan_streams] = mockSourceStream(context, max_streams, log, executor_id, table_scan.getLogicalTableID()); + NamesAndTypes names_and_types; + std::vector> mock_table_scan_streams; + if (context.isMPPTest()) + { + std::tie(names_and_types, mock_table_scan_streams) = mockSourceStreamForMpp(context, max_streams, log, table_scan); + } + else + { + std::tie(names_and_types, mock_table_scan_streams) = mockSourceStream(context, max_streams, log, executor_id, table_scan.getLogicalTableID()); + } schema = std::move(names_and_types); mock_streams.insert(mock_streams.end(), mock_table_scan_streams.begin(), mock_table_scan_streams.end()); } diff --git a/dbms/src/Flash/Planner/plans/PhysicalTableScan.cpp b/dbms/src/Flash/Planner/plans/PhysicalTableScan.cpp index 06710f4dc98..ab8c60a4ba3 100644 --- a/dbms/src/Flash/Planner/plans/PhysicalTableScan.cpp +++ b/dbms/src/Flash/Planner/plans/PhysicalTableScan.cpp @@ -17,7 +17,6 @@ #include #include #include -#include #include #include #include diff --git a/dbms/src/Flash/tests/gtest_compute_server.cpp b/dbms/src/Flash/tests/gtest_compute_server.cpp index 264db3ea876..ab53fe00392 100644 --- a/dbms/src/Flash/tests/gtest_compute_server.cpp +++ b/dbms/src/Flash/tests/gtest_compute_server.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include namespace DB @@ -34,6 +35,24 @@ class ComputeServerRunner : public DB::tests::MPPTaskTestUtils {{"s1", TiDB::TP::TypeLong}, {"s2", TiDB::TP::TypeString}, {"s3", TiDB::TP::TypeString}}, {toNullableVec("s1", {1, {}, 10000000, 10000000}), toNullableVec("s2", {"apple", {}, "banana", "test"}), toNullableVec("s3", {"apple", {}, "banana", "test"})}); + /// agg table with 200 rows + std::vector::FieldType>> agg_s1(200); + std::vector> agg_s2(200); + std::vector> agg_s3(200); + for (size_t i = 0; i < 200; ++i) + { + if (i % 30 != 0) + { + agg_s1[i] = i % 20; + agg_s2[i] = {fmt::format("val_{}", i % 10)}; + agg_s3[i] = {fmt::format("val_{}", i)}; + } + } + context.addMockTable( + {"test_db", "test_table_2"}, + {{"s1", TiDB::TP::TypeLong}, {"s2", TiDB::TP::TypeString}, {"s3", TiDB::TP::TypeString}}, + {toNullableVec("s1", agg_s1), toNullableVec("s2", agg_s2), toNullableVec("s3", agg_s3)}); + /// for join context.addMockTable( {"test_db", "l_table"}, @@ -43,9 +62,46 @@ class ComputeServerRunner : public DB::tests::MPPTaskTestUtils {"test_db", "r_table"}, {{"s", TiDB::TP::TypeString}, {"join_c", TiDB::TP::TypeString}}, {toNullableVec("s", {"banana", {}, "banana"}), toNullableVec("join_c", {"apple", {}, "banana"})}); + + /// join left table with 200 rows + std::vector::FieldType>> join_s1(200); + std::vector> join_s2(200); + std::vector> join_s3(200); + for (size_t i = 0; i < 200; ++i) + { + if (i % 20 != 0) + { + agg_s1[i] = i % 5; + agg_s2[i] = {fmt::format("val_{}", i % 6)}; + agg_s3[i] = {fmt::format("val_{}", i)}; + } + } + context.addMockTable( + {"test_db", "l_table_2"}, + {{"s1", TiDB::TP::TypeLong}, {"s2", TiDB::TP::TypeString}, {"s3", TiDB::TP::TypeString}}, + {toNullableVec("s1", agg_s1), toNullableVec("s2", agg_s2), toNullableVec("s3", agg_s3)}); + + /// join right table with 100 rows + std::vector::FieldType>> join_r_s1(100); + std::vector> join_r_s2(100); + std::vector> join_r_s3(100); + for (size_t i = 0; i < 100; ++i) + { + if (i % 20 != 0) + { + join_r_s1[i] = i % 6; + join_r_s2[i] = {fmt::format("val_{}", i % 7)}; + join_r_s3[i] = {fmt::format("val_{}", i)}; + } + } + context.addMockTable( + {"test_db", "r_table_2"}, + {{"s1", TiDB::TP::TypeLong}, {"s2", TiDB::TP::TypeString}, {"s3", TiDB::TP::TypeString}}, + {toNullableVec("s1", join_r_s1), toNullableVec("s2", join_r_s2), toNullableVec("s3", join_r_s3)}); } }; + TEST_F(ComputeServerRunner, runAggTasks) try { @@ -290,6 +346,53 @@ try } CATCH +TEST_F(ComputeServerRunner, aggWithColumnPrune) +try +{ + startServers(3); + + context.addMockTable( + {"test_db", "test_table_2"}, + {{"i1", TiDB::TP::TypeLong}, {"i2", TiDB::TP::TypeLong}, {"s1", TiDB::TP::TypeString}, {"s2", TiDB::TP::TypeString}, {"s3", TiDB::TP::TypeString}, {"s4", TiDB::TP::TypeString}, {"s5", TiDB::TP::TypeString}}, + {toNullableVec("i1", {0, 0, 0}), toNullableVec("i2", {1, 1, 1}), toNullableVec("s1", {"1", "9", "8"}), toNullableVec("s2", {"1", "9", "8"}), toNullableVec("s3", {"4", "9", "99"}), toNullableVec("s4", {"4", "9", "999"}), toNullableVec("s5", {"4", "9", "9999"})}); + std::vector res{"9", "9", "99", "999", "9999"}; + std::vector max_cols{"s1", "s2", "s3", "s4", "s5"}; + for (size_t i = 0; i < 1; ++i) + { + { + auto request = context + .scan("test_db", "test_table_2") + .aggregation({Max(col(max_cols[i]))}, {col("i1")}); + auto expected_cols = { + toNullableVec({res[i]}), + toNullableVec({{0}})}; + ASSERT_COLUMNS_EQ_UR(expected_cols, buildAndExecuteMPPTasks(request)); + } + + { + auto request = context + .scan("test_db", "test_table_2") + .aggregation({Max(col(max_cols[i]))}, {col("i2")}); + auto expected_cols = { + toNullableVec({res[i]}), + toNullableVec({{1}})}; + ASSERT_COLUMNS_EQ_UR(expected_cols, buildAndExecuteMPPTasks(request)); + } + + { + auto request = context + .scan("test_db", "test_table_2") + .aggregation({Max(col(max_cols[i]))}, {col("i1"), col("i2")}); + auto expected_cols = { + toNullableVec({res[i]}), + toNullableVec({{0}}), + toNullableVec({{1}})}; + ASSERT_COLUMNS_EQ_UR(expected_cols, buildAndExecuteMPPTasks(request)); + } + } +} +CATCH + TEST_F(ComputeServerRunner, cancelAggTasks) try { @@ -398,5 +501,121 @@ try } } CATCH + +/// For FineGrainedShuffleJoin/Agg test usage, update internal exchange senders/receivers flag +/// Allow select,agg,join,tableScan,exchangeSender,exchangeReceiver,projection executors only +void setFineGrainedShuffleForExchange(tipb::Executor & root) +{ + tipb::Executor * current = &root; + while (current) + { + switch (current->tp()) + { + case tipb::ExecType::TypeSelection: + current = const_cast(¤t->selection().child()); + break; + case tipb::ExecType::TypeAggregation: + current = const_cast(¤t->aggregation().child()); + break; + case tipb::ExecType::TypeProjection: + current = const_cast(¤t->projection().child()); + break; + case tipb::ExecType::TypeJoin: + { + /// update build side path + JoinInterpreterHelper::TiFlashJoin tiflash_join{current->join()}; + current = const_cast(¤t->join().children()[tiflash_join.build_side_index]); + break; + } + case tipb::ExecType::TypeExchangeSender: + if (current->exchange_sender().tp() == tipb::Hash) + current->set_fine_grained_shuffle_stream_count(8); + current = const_cast(¤t->exchange_sender().child()); + break; + case tipb::ExecType::TypeExchangeReceiver: + current->set_fine_grained_shuffle_stream_count(8); + current = nullptr; + break; + case tipb::ExecType::TypeTableScan: + current = nullptr; + break; + default: + throw TiFlashException("Should not reach here", Errors::Coprocessor::Internal); + } + } +} + +TEST_F(ComputeServerRunner, runFineGrainedShuffleJoinTest) +try +{ + startServers(3); + constexpr size_t join_type_num = 7; + constexpr tipb::JoinType join_types[join_type_num] = { + tipb::JoinType::TypeInnerJoin, + tipb::JoinType::TypeLeftOuterJoin, + tipb::JoinType::TypeRightOuterJoin, + tipb::JoinType::TypeSemiJoin, + tipb::JoinType::TypeAntiSemiJoin, + tipb::JoinType::TypeLeftOuterSemiJoin, + tipb::JoinType::TypeAntiLeftOuterSemiJoin, + }; + // fine-grained shuffle is enabled. + constexpr uint64_t enable = 8; + constexpr uint64_t disable = 0; + + for (auto join_type : join_types) + { + std::cout << "JoinType: " << static_cast(join_type) << std::endl; + auto properties = DB::tests::getDAGPropertiesForTest(serverNum()); + auto request = context + .scan("test_db", "l_table_2") + .join(context.scan("test_db", "r_table_2"), join_type, {col("s1"), col("s2")}, disable) + .project({col("l_table_2.s1"), col("l_table_2.s2"), col("l_table_2.s3")}); + const auto expected_cols = buildAndExecuteMPPTasks(request); + + auto request2 = context + .scan("test_db", "l_table_2") + .join(context.scan("test_db", "r_table_2"), join_type, {col("s1"), col("s2")}, enable) + .project({col("l_table_2.s1"), col("l_table_2.s2"), col("l_table_2.s3")}); + auto tasks = request2.buildMPPTasks(context, properties); + for (auto & task : tasks) + { + setFineGrainedShuffleForExchange(const_cast(task.dag_request->root_executor())); + } + const auto actual_cols = executeMPPTasks(tasks, properties, MockComputeServerManager::instance().getServerConfigMap()); + ASSERT_COLUMNS_EQ_UR(expected_cols, actual_cols); + } +} +CATCH + +TEST_F(ComputeServerRunner, runFineGrainedShuffleAggTest) +try +{ + startServers(3); + // fine-grained shuffle is enabled. + constexpr uint64_t enable = 8; + constexpr uint64_t disable = 0; + { + auto properties = DB::tests::getDAGPropertiesForTest(serverNum()); + auto request = context + .scan("test_db", "test_table_2") + .aggregation({Max(col("s3"))}, {col("s1"), col("s2")}, disable); + const auto expected_cols = buildAndExecuteMPPTasks(request); + + auto request2 = context + .scan("test_db", "test_table_2") + .aggregation({Max(col("s3"))}, {col("s1"), col("s2")}, enable); + auto tasks = request2.buildMPPTasks(context, properties); + for (auto & task : tasks) + { + setFineGrainedShuffleForExchange(const_cast(task.dag_request->root_executor())); + } + + const auto actual_cols = executeMPPTasks(tasks, properties, MockComputeServerManager::instance().getServerConfigMap()); + ASSERT_COLUMNS_EQ_UR(expected_cols, actual_cols); + } +} +CATCH + } // namespace tests } // namespace DB diff --git a/dbms/src/Flash/tests/gtest_execution_summary.cpp b/dbms/src/Flash/tests/gtest_execution_summary.cpp new file mode 100644 index 00000000000..e010cfb5a53 --- /dev/null +++ b/dbms/src/Flash/tests/gtest_execution_summary.cpp @@ -0,0 +1,140 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +namespace DB +{ +namespace tests +{ +class ExecutionSummaryTestRunner : public DB::tests::ExecutorTest +{ +public: + void initializeContext() override + { + ExecutorTest::initializeContext(); + context.addMockTable({"test_db", "test_table"}, + {{"s1", TiDB::TP::TypeString}, {"s2", TiDB::TP::TypeString}}, + {toNullableVec("s1", {"banana", {}, "banana", "banana", {}, "banana", "banana", {}, "banana", "banana", {}, "banana"}), + toNullableVec("s2", {"apple", {}, "banana", "apple", {}, "banana", "apple", {}, "banana", "apple", {}, "banana"})}); + context.addExchangeReceiver("test_exchange", + {{"s1", TiDB::TP::TypeString}, {"s2", TiDB::TP::TypeString}}, + {toNullableVec("s1", {"banana", {}, "banana", "banana", {}, "banana", "banana", {}, "banana", "banana", {}, "banana"}), + toNullableVec("s2", {"apple", {}, "banana", "apple", {}, "banana", "apple", {}, "banana", "apple", {}, "banana"})}); + } + + static constexpr size_t concurrency = 10; + static constexpr int not_check_rows = -1; + // + using ProfileInfo = std::pair; + using Expect = std::unordered_map; + void testForExecutionSummary( + const std::shared_ptr & request, + const Expect & expect) + { + request->set_collect_execution_summaries(true); + DAGContext dag_context(*request, "test_execution_summary", concurrency); + executeStreams(&dag_context); + ASSERT_EQ(dag_context.getProfileStreamsMap().size(), expect.size()); + ASSERT_TRUE(dag_context.collect_execution_summaries); + ExecutionSummaryCollector summary_collector(dag_context); + auto summaries = summary_collector.genExecutionSummaryResponse().execution_summaries(); + ASSERT_EQ(summaries.size(), expect.size()); + for (const auto & summary : summaries) + { + ASSERT_TRUE(summary.has_executor_id()); + auto it = expect.find(summary.executor_id()); + ASSERT_TRUE(it != expect.end()) << fmt::format("unknown executor_id: {}", summary.executor_id()); + if (it->second.first != not_check_rows) + ASSERT_EQ(summary.num_produced_rows(), it->second.first) << fmt::format("executor_id: {}", summary.executor_id()); + ASSERT_EQ(summary.concurrency(), it->second.second) << fmt::format("executor_id: {}", summary.executor_id()); + // time_processed_ns, num_iterations and tiflash_scan_context are not checked here. + } + } +}; + +TEST_F(ExecutionSummaryTestRunner, test) +try +{ + { + auto request = context + .scan("test_db", "test_table") + .filter(eq(col("s1"), col("s2"))) + .build(context); + Expect expect{{"table_scan_0", {12, concurrency}}, {"selection_1", {4, concurrency}}}; + testForExecutionSummary(request, expect); + } + { + auto request = context + .scan("test_db", "test_table") + .limit(5) + .build(context); + Expect expect{{"table_scan_0", {not_check_rows, concurrency}}, {"limit_1", {5, 1}}}; + testForExecutionSummary(request, expect); + } + { + auto request = context + .scan("test_db", "test_table") + .topN("s1", true, 5) + .build(context); + Expect expect{{"table_scan_0", {not_check_rows, concurrency}}, {"topn_1", {5, 1}}}; + testForExecutionSummary(request, expect); + } + { + auto request = context + .scan("test_db", "test_table") + .project({col("s2")}) + .build(context); + Expect expect{{"table_scan_0", {12, concurrency}}, {"project_1", {12, concurrency}}}; + testForExecutionSummary(request, expect); + } + { + auto request = context + .scan("test_db", "test_table") + .aggregation({col("s2")}, {col("s2")}) + .build(context); + Expect expect{{"table_scan_0", {12, concurrency}}, {"aggregation_1", {3, concurrency}}}; + testForExecutionSummary(request, expect); + } + { + auto t1 = context.scan("test_db", "test_table"); + auto t2 = context.scan("test_db", "test_table"); + auto request = t1.join(t2, tipb::JoinType::TypeInnerJoin, {col("s1")}).build(context); + Expect expect{{"table_scan_0", {12, concurrency}}, {"table_scan_1", {12, concurrency}}, {"Join_2", {64, concurrency}}}; + testForExecutionSummary(request, expect); + } + { + auto request = context + .receive("test_exchange") + .exchangeSender(tipb::Hash) + .build(context); + Expect expect{{"exchange_receiver_0", {12, concurrency}}, {"exchange_sender_1", {12, concurrency}}}; + testForExecutionSummary(request, expect); + } + { + auto request = context + .receive("test_exchange") + .sort({{"s1", false}, {"s2", false}, {"s1", false}, {"s2", false}}, true) + .window(RowNumber(), {"s1", false}, {"s2", false}, buildDefaultRowsFrame()) + .build(context); + Expect expect{{"exchange_receiver_0", {12, concurrency}}, {"sort_1", {12, 1}}, {"window_2", {12, 1}}}; + testForExecutionSummary(request, expect); + } +} +CATCH + +} // namespace tests +} // namespace DB diff --git a/dbms/src/Flash/tests/gtest_interpreter.cpp b/dbms/src/Flash/tests/gtest_interpreter.cpp index 736166929bc..0afa65390ac 100644 --- a/dbms/src/Flash/tests/gtest_interpreter.cpp +++ b/dbms/src/Flash/tests/gtest_interpreter.cpp @@ -391,12 +391,110 @@ Union: } CATCH +TEST_F(InterpreterExecuteTest, FineGrainedShuffleJoin) +try +{ + // fine-grained shuffle is enabled. + const uint64_t enable = 8; + const uint64_t disable = 0; + { + // Join Source. + DAGRequestBuilder receiver1 = context.receive("sender_l"); + DAGRequestBuilder receiver2 = context.receive("sender_r", enable); + + auto request = receiver1.join( + receiver2, + tipb::JoinType::TypeLeftOuterJoin, + {col("join_c")}, + enable) + .build(context); + + String expected = R"( +CreatingSets + Union: + HashJoinBuild x 10: , join_kind = Left + Expression: + Expression: + MockExchangeReceiver + Union: + Expression x 10: + Expression: + HashJoinProbe: + Expression: + MockExchangeReceiver)"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } + { + // Join Source. + DAGRequestBuilder receiver1 = context.receive("sender_l"); + DAGRequestBuilder receiver2 = context.receive("sender_r", disable); + + auto request = receiver1.join( + receiver2, + tipb::JoinType::TypeLeftOuterJoin, + {col("join_c")}, + disable) + .build(context); + + String expected = R"( +CreatingSets + Union: + HashJoinBuild x 10: , join_kind = Left + Expression: + Expression: + MockExchangeReceiver + Union: + Expression x 10: + Expression: + HashJoinProbe: + Expression: + MockExchangeReceiver)"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } +} +CATCH + +TEST_F(InterpreterExecuteTest, FineGrainedShuffleAgg) +try +{ + // fine-grained shuffle is enabled. + const uint64_t enable = 8; + const uint64_t disable = 0; + { + DAGRequestBuilder receiver1 = context.receive("sender_1", enable); + auto request = receiver1 + .aggregation({Max(col("s1"))}, {col("s2")}, enable) + .build(context); + String expected = R"( +Union: + Expression x 10: + Aggregating: + MockExchangeReceiver)"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } + + { + DAGRequestBuilder receiver1 = context.receive("sender_1", disable); + auto request = receiver1 + .aggregation({Max(col("s1"))}, {col("s2")}, disable) + .build(context); + String expected = R"( +Union: + Expression x 10: + SharedQuery: + ParallelAggregating, max_threads: 10, final: true + MockExchangeReceiver x 10)"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } +} +CATCH + TEST_F(InterpreterExecuteTest, Join) try { // TODO: Find a way to write the request easier. { - // Join Source. + // join + ExchangeReceiver DAGRequestBuilder table1 = context.scan("test_db", "r_table"); DAGRequestBuilder table2 = context.scan("test_db", "l_table"); DAGRequestBuilder table3 = context.scan("test_db", "r_table"); diff --git a/dbms/src/Flash/tests/gtest_planner_interpreter.cpp b/dbms/src/Flash/tests/gtest_planner_interpreter.cpp index e9f99891642..eb6de71ca4e 100644 --- a/dbms/src/Flash/tests/gtest_planner_interpreter.cpp +++ b/dbms/src/Flash/tests/gtest_planner_interpreter.cpp @@ -723,6 +723,108 @@ Union: } CATCH +TEST_F(PlannerInterpreterExecuteTest, FineGrainedShuffleJoin) +try +{ + // fine-grained shuffle is enabled. + const uint64_t enable = 8; + const uint64_t disable = 0; + { + // Join Source. + DAGRequestBuilder receiver1 = context.receive("sender_l"); + DAGRequestBuilder receiver2 = context.receive("sender_r", enable); + + auto request = receiver1.join( + receiver2, + tipb::JoinType::TypeLeftOuterJoin, + {col("join_c")}, + enable) + .build(context); + + String expected = R"( +CreatingSets + Union: + HashJoinBuild x 10: , join_kind = Left + Expression: + Expression: + MockExchangeReceiver + Union: + Expression x 10: + Expression: + HashJoinProbe: + Expression: + MockExchangeReceiver)"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } + { + // Join Source. + DAGRequestBuilder receiver1 = context.receive("sender_l"); + DAGRequestBuilder receiver2 = context.receive("sender_r", disable); + + auto request = receiver1.join( + receiver2, + tipb::JoinType::TypeLeftOuterJoin, + {col("join_c")}, + disable) + .build(context); + + String expected = R"( +CreatingSets + Union: + HashJoinBuild x 10: , join_kind = Left + Expression: + Expression: + MockExchangeReceiver + Union: + Expression x 10: + Expression: + HashJoinProbe: + Expression: + MockExchangeReceiver)"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } +} +CATCH + +TEST_F(PlannerInterpreterExecuteTest, FineGrainedShuffleAgg) +try +{ + // fine-grained shuffle is enabled. + const uint64_t enable = 8; + const uint64_t disable = 0; + { + DAGRequestBuilder receiver1 = context.receive("sender_1", enable); + auto request = receiver1 + .aggregation({Max(col("s1"))}, {col("s2")}, enable) + .build(context); + String expected = R"( +Union: + Expression x 10: + Expression: + Aggregating: + Expression: + MockExchangeReceiver)"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } + + { + DAGRequestBuilder receiver1 = context.receive("sender_1", disable); + auto request = receiver1 + .aggregation({Max(col("s1"))}, {col("s2")}, disable) + .build(context); + String expected = R"( +Union: + Expression x 10: + Expression: + SharedQuery: + ParallelAggregating, max_threads: 10, final: true + Expression x 10: + MockExchangeReceiver)"; + ASSERT_BLOCKINPUTSTREAM_EQAUL(expected, request, 10); + } +} +CATCH + TEST_F(PlannerInterpreterExecuteTest, Join) try { diff --git a/dbms/src/Interpreters/Aggregator.cpp b/dbms/src/Interpreters/Aggregator.cpp index b7193833031..4a9ffc1c993 100644 --- a/dbms/src/Interpreters/Aggregator.cpp +++ b/dbms/src/Interpreters/Aggregator.cpp @@ -32,6 +32,7 @@ #include #include #include +#include #include #include @@ -56,7 +57,6 @@ extern const int LOGICAL_ERROR; namespace FailPoints { extern const char random_aggregate_create_state_failpoint[]; -extern const char random_aggregate_merge_failpoint[]; } // namespace FailPoints #define AggregationMethodName(NAME) AggregatedDataVariants::AggregationMethod_##NAME @@ -1773,223 +1773,6 @@ void NO_INLINE Aggregator::mergeBucketImpl( } -/** Combines aggregation states together, turns them into blocks, and outputs streams. - * If the aggregation states are two-level, then it produces blocks strictly in order of 'bucket_num'. - * (This is important for distributed processing.) - * In doing so, it can handle different buckets in parallel, using up to `threads` threads. - */ -class MergingAndConvertingBlockInputStream : public IProfilingBlockInputStream -{ -public: - /** The input is a set of non-empty sets of partially aggregated data, - * which are all either single-level, or are two-level. - */ - MergingAndConvertingBlockInputStream(const Aggregator & aggregator_, ManyAggregatedDataVariants & data_, bool final_, size_t threads_) - : log(Logger::get(aggregator_.log ? aggregator_.log->identifier() : "")) - , aggregator(aggregator_) - , data(data_) - , final(final_) - , threads(threads_) - { - /// At least we need one arena in first data item per thread - if (!data.empty() && threads > data[0]->aggregates_pools.size()) - { - Arenas & first_pool = data[0]->aggregates_pools; - for (size_t j = first_pool.size(); j < threads; ++j) - first_pool.emplace_back(std::make_shared()); - } - } - - String getName() const override { return "MergingAndConverting"; } - - Block getHeader() const override { return aggregator.getHeader(final); } - - ~MergingAndConvertingBlockInputStream() override - { - LOG_TRACE(&Poco::Logger::get(__PRETTY_FUNCTION__), "Waiting for threads to finish"); - - /// We need to wait for threads to finish before destructor of 'parallel_merge_data', - /// because the threads access 'parallel_merge_data'. - if (parallel_merge_data && parallel_merge_data->thread_pool) - parallel_merge_data->thread_pool->wait(); - } - -protected: - Block readImpl() override - { - if (data.empty()) - return {}; - - if (current_bucket_num >= NUM_BUCKETS) - return {}; - - FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::random_aggregate_merge_failpoint); - - AggregatedDataVariantsPtr & first = data[0]; - - if (current_bucket_num == -1) - { - ++current_bucket_num; - - if (first->type == AggregatedDataVariants::Type::without_key || aggregator.params.overflow_row) - { - aggregator.mergeWithoutKeyDataImpl(data); - return aggregator.prepareBlockAndFillWithoutKey( - *first, - final, - first->type != AggregatedDataVariants::Type::without_key); - } - } - - if (!first->isTwoLevel()) - { - if (current_bucket_num > 0) - return {}; - - if (first->type == AggregatedDataVariants::Type::without_key) - return {}; - - ++current_bucket_num; - -#define M(NAME) \ - case AggregationMethodType(NAME): \ - { \ - aggregator.mergeSingleLevelDataImpl(data); \ - break; \ - } - switch (first->type) - { - APPLY_FOR_VARIANTS_SINGLE_LEVEL(M) - default: - throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); - } -#undef M - return aggregator.prepareBlockAndFillSingleLevel(*first, final); - } - else - { - if (!parallel_merge_data) - { - parallel_merge_data = std::make_unique(threads); - for (size_t i = 0; i < threads; ++i) - scheduleThreadForNextBucket(); - } - - Block res; - - while (true) - { - std::unique_lock lock(parallel_merge_data->mutex); - - if (parallel_merge_data->exception) - std::rethrow_exception(parallel_merge_data->exception); - - auto it = parallel_merge_data->ready_blocks.find(current_bucket_num); - if (it != parallel_merge_data->ready_blocks.end()) - { - ++current_bucket_num; - scheduleThreadForNextBucket(); - - if (it->second) - { - res.swap(it->second); - break; - } - else if (current_bucket_num >= NUM_BUCKETS) - break; - } - - parallel_merge_data->condvar.wait(lock); - } - - return res; - } - } - -private: - const LoggerPtr log; - const Aggregator & aggregator; - ManyAggregatedDataVariants data; - bool final; - size_t threads; - - std::atomic current_bucket_num = -1; - std::atomic max_scheduled_bucket_num = -1; - static constexpr Int32 NUM_BUCKETS = 256; - - struct ParallelMergeData - { - std::map ready_blocks; - std::exception_ptr exception; - std::mutex mutex; - std::condition_variable condvar; - std::shared_ptr thread_pool; - - explicit ParallelMergeData(size_t threads) - : thread_pool(newThreadPoolManager(threads)) - {} - }; - - std::unique_ptr parallel_merge_data; - - void scheduleThreadForNextBucket() - { - int num = max_scheduled_bucket_num.fetch_add(1) + 1; - if (num >= NUM_BUCKETS) - return; - - parallel_merge_data->thread_pool->schedule(true, [this, num] { thread(num); }); - } - - void thread(Int32 bucket_num) - { - try - { - /// TODO: add no_more_keys support maybe - - auto & merged_data = *data[0]; - auto method = merged_data.type; - Block block; - - /// Select Arena to avoid race conditions - size_t thread_number = static_cast(bucket_num) % threads; - Arena * arena = merged_data.aggregates_pools.at(thread_number).get(); - -#define M(NAME) \ - case AggregationMethodType(NAME): \ - { \ - aggregator.mergeBucketImpl(data, bucket_num, arena); \ - block = aggregator.convertOneBucketToBlock( \ - merged_data, \ - *ToAggregationMethodPtr(NAME, merged_data.aggregation_method_impl), \ - arena, \ - final, \ - bucket_num); \ - break; \ - } - switch (method) - { - APPLY_FOR_VARIANTS_TWO_LEVEL(M) - default: - break; - } -#undef M - - std::lock_guard lock(parallel_merge_data->mutex); - parallel_merge_data->ready_blocks[bucket_num] = std::move(block); - } - catch (...) - { - std::lock_guard lock(parallel_merge_data->mutex); - if (!parallel_merge_data->exception) - parallel_merge_data->exception = std::current_exception(); - } - - parallel_merge_data->condvar.notify_all(); - } -}; - - std::unique_ptr Aggregator::mergeAndConvertToBlocks( ManyAggregatedDataVariants & data_variants, bool final, @@ -2751,5 +2534,11 @@ void Aggregator::setCancellationHook(CancellationHook cancellation_hook) is_cancelled = cancellation_hook; } +#undef AggregationMethodName +#undef AggregationMethodNameTwoLevel +#undef AggregationMethodType +#undef AggregationMethodTypeTwoLevel +#undef ToAggregationMethodPtr +#undef ToAggregationMethodPtrTwoLevel } // namespace DB diff --git a/dbms/src/Interpreters/InterpreterSelectQuery.cpp b/dbms/src/Interpreters/InterpreterSelectQuery.cpp index 47571a6f860..f1a39672a99 100644 --- a/dbms/src/Interpreters/InterpreterSelectQuery.cpp +++ b/dbms/src/Interpreters/InterpreterSelectQuery.cpp @@ -27,7 +27,6 @@ #include #include #include -#include #include #include #include @@ -616,12 +615,6 @@ void InterpreterSelectQuery::executeImpl(Pipeline & pipeline, const BlockInputSt if (need_second_distinct_pass) executeDistinct(pipeline, false, expressions.selected_columns); - if (expressions.has_limit_by) - { - executeExpression(pipeline, expressions.before_limit_by); - executeLimitBy(pipeline); - } - /** We must do projection after DISTINCT because projection may remove some columns. */ executeProjection(pipeline, expressions.final_projection); @@ -1291,23 +1284,6 @@ void InterpreterSelectQuery::executePreLimit(Pipeline & pipeline) } -void InterpreterSelectQuery::executeLimitBy(Pipeline & pipeline) // NOLINT -{ - if (!query.limit_by_value || !query.limit_by_expression_list) - return; - - Names columns; - for (const auto & elem : query.limit_by_expression_list->children) - columns.emplace_back(elem->getColumnName()); - - auto value = safeGet(typeid_cast(*query.limit_by_value).value); - - pipeline.transform([&](auto & stream) { - stream = std::make_shared(stream, value, columns); - }); -} - - bool hasWithTotalsInAnySubqueryInFromClause(const ASTSelectQuery & query) { if (query.group_by_with_totals) diff --git a/dbms/src/Interpreters/Join.cpp b/dbms/src/Interpreters/Join.cpp index 15d94e0c195..102fa05f9f6 100644 --- a/dbms/src/Interpreters/Join.cpp +++ b/dbms/src/Interpreters/Join.cpp @@ -1214,7 +1214,10 @@ void NO_INLINE joinBlockImplTypeCase( /// 2. In ExchangeReceiver, build_stream_id = packet_stream_id % build_stream_count; /// 3. In HashBuild, build_concurrency decides map's segment size, and build_steam_id decides the segment index auto packet_stream_id = shuffle_hash_data[i] % fine_grained_shuffle_count; - segment_index = packet_stream_id % segment_size; + if likely (fine_grained_shuffle_count == segment_size) + segment_index = packet_stream_id; + else + segment_index = packet_stream_id % segment_size; } else { diff --git a/dbms/src/Interpreters/MergingAndConvertingBlockInputStream.h b/dbms/src/Interpreters/MergingAndConvertingBlockInputStream.h new file mode 100644 index 00000000000..7e58eb8da81 --- /dev/null +++ b/dbms/src/Interpreters/MergingAndConvertingBlockInputStream.h @@ -0,0 +1,252 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +namespace DB +{ +namespace FailPoints +{ +extern const char random_aggregate_merge_failpoint[]; +} // namespace FailPoints + +#define AggregationMethodName(NAME) AggregatedDataVariants::AggregationMethod_##NAME +#define AggregationMethodType(NAME) AggregatedDataVariants::Type::NAME +#define ToAggregationMethodPtr(NAME, ptr) (reinterpret_cast(ptr)) + +/** Combines aggregation states together, turns them into blocks, and outputs streams. + * If the aggregation states are two-level, then it produces blocks strictly in order of 'bucket_num'. + * (This is important for distributed processing.) + * In doing so, it can handle different buckets in parallel, using up to `threads` threads. + */ +class MergingAndConvertingBlockInputStream : public IProfilingBlockInputStream +{ +public: + /** The input is a set of non-empty sets of partially aggregated data, + * which are all either single-level, or are two-level. + */ + MergingAndConvertingBlockInputStream(const Aggregator & aggregator_, ManyAggregatedDataVariants & data_, bool final_, size_t threads_) + : log(Logger::get(aggregator_.log ? aggregator_.log->identifier() : "")) + , aggregator(aggregator_) + , data(data_) + , final(final_) + , threads(threads_) + { + /// At least we need one arena in first data item per thread + if (!data.empty() && threads > data[0]->aggregates_pools.size()) + { + Arenas & first_pool = data[0]->aggregates_pools; + for (size_t j = first_pool.size(); j < threads; ++j) + first_pool.emplace_back(std::make_shared()); + } + } + + String getName() const override { return "MergingAndConverting"; } + + Block getHeader() const override { return aggregator.getHeader(final); } + + ~MergingAndConvertingBlockInputStream() override + { + LOG_TRACE(&Poco::Logger::get(__PRETTY_FUNCTION__), "Waiting for threads to finish"); + + /// We need to wait for threads to finish before destructor of 'parallel_merge_data', + /// because the threads access 'parallel_merge_data'. + if (parallel_merge_data && parallel_merge_data->thread_pool) + parallel_merge_data->thread_pool->wait(); + } + +protected: + Block readImpl() override + { + if (data.empty()) + return {}; + + if (current_bucket_num >= NUM_BUCKETS) + return {}; + + FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::random_aggregate_merge_failpoint); + + AggregatedDataVariantsPtr & first = data[0]; + + if (current_bucket_num == -1) + { + ++current_bucket_num; + + if (first->type == AggregatedDataVariants::Type::without_key || aggregator.params.overflow_row) + { + aggregator.mergeWithoutKeyDataImpl(data); + return aggregator.prepareBlockAndFillWithoutKey( + *first, + final, + first->type != AggregatedDataVariants::Type::without_key); + } + } + + if (!first->isTwoLevel()) + { + if (current_bucket_num > 0) + return {}; + + if (first->type == AggregatedDataVariants::Type::without_key) + return {}; + + ++current_bucket_num; + +#define M(NAME) \ + case AggregationMethodType(NAME): \ + { \ + aggregator.mergeSingleLevelDataImpl(data); \ + break; \ + } + switch (first->type) + { + APPLY_FOR_VARIANTS_SINGLE_LEVEL(M) + default: + throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT); + } +#undef M + return aggregator.prepareBlockAndFillSingleLevel(*first, final); + } + else + { + if (!parallel_merge_data) + { + parallel_merge_data = std::make_unique(threads); + for (size_t i = 0; i < threads; ++i) + scheduleThreadForNextBucket(); + } + + Block res; + + while (true) + { + std::unique_lock lock(parallel_merge_data->mutex); + + if (parallel_merge_data->exception) + std::rethrow_exception(parallel_merge_data->exception); + + auto it = parallel_merge_data->ready_blocks.find(current_bucket_num); + if (it != parallel_merge_data->ready_blocks.end()) + { + ++current_bucket_num; + scheduleThreadForNextBucket(); + + if (it->second) + { + res.swap(it->second); + break; + } + else if (current_bucket_num >= NUM_BUCKETS) + break; + } + + parallel_merge_data->condvar.wait(lock); + } + + return res; + } + } + +private: + const LoggerPtr log; + const Aggregator & aggregator; + ManyAggregatedDataVariants data; + bool final; + size_t threads; + + std::atomic current_bucket_num = -1; + std::atomic max_scheduled_bucket_num = -1; + static constexpr Int32 NUM_BUCKETS = 256; + + struct ParallelMergeData + { + std::map ready_blocks; + std::exception_ptr exception; + std::mutex mutex; + std::condition_variable condvar; + std::shared_ptr thread_pool; + + explicit ParallelMergeData(size_t threads) + : thread_pool(newThreadPoolManager(threads)) + {} + }; + + std::unique_ptr parallel_merge_data; + + void scheduleThreadForNextBucket() + { + int num = max_scheduled_bucket_num.fetch_add(1) + 1; + if (num >= NUM_BUCKETS) + return; + + parallel_merge_data->thread_pool->schedule(true, [this, num] { thread(num); }); + } + + void thread(Int32 bucket_num) + { + try + { + /// TODO: add no_more_keys support maybe + + auto & merged_data = *data[0]; + auto method = merged_data.type; + Block block; + + /// Select Arena to avoid race conditions + size_t thread_number = static_cast(bucket_num) % threads; + Arena * arena = merged_data.aggregates_pools.at(thread_number).get(); + +#define M(NAME) \ + case AggregationMethodType(NAME): \ + { \ + aggregator.mergeBucketImpl(data, bucket_num, arena); \ + block = aggregator.convertOneBucketToBlock( \ + merged_data, \ + *ToAggregationMethodPtr(NAME, merged_data.aggregation_method_impl), \ + arena, \ + final, \ + bucket_num); \ + break; \ + } + switch (method) + { + APPLY_FOR_VARIANTS_TWO_LEVEL(M) + default: + break; + } +#undef M + + std::lock_guard lock(parallel_merge_data->mutex); + parallel_merge_data->ready_blocks[bucket_num] = std::move(block); + } + catch (...) + { + std::lock_guard lock(parallel_merge_data->mutex); + if (!parallel_merge_data->exception) + parallel_merge_data->exception = std::current_exception(); + } + + parallel_merge_data->condvar.notify_all(); + } +}; + +#undef AggregationMethodName +#undef AggregationMethodType +#undef ToAggregationMethodPtr +} // namespace DB diff --git a/dbms/src/Server/FlashGrpcServerHolder.cpp b/dbms/src/Server/FlashGrpcServerHolder.cpp index 1aeff0a49c6..0ab13e2dd85 100644 --- a/dbms/src/Server/FlashGrpcServerHolder.cpp +++ b/dbms/src/Server/FlashGrpcServerHolder.cpp @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include diff --git a/dbms/src/Storages/DeltaMerge/ScanContext.h b/dbms/src/Storages/DeltaMerge/ScanContext.h index 590223caef6..63d2081092c 100644 --- a/dbms/src/Storages/DeltaMerge/ScanContext.h +++ b/dbms/src/Storages/DeltaMerge/ScanContext.h @@ -81,6 +81,17 @@ class ScanContext total_dmfile_read_time_ms += other.total_dmfile_read_time_ms; total_create_snapshot_time_ms += other.total_create_snapshot_time_ms; } + + void merge(const tipb::TiFlashScanContext & other) + { + total_dmfile_scanned_packs += other.total_dmfile_scanned_packs(); + total_dmfile_skipped_packs += other.total_dmfile_skipped_packs(); + total_dmfile_scanned_rows += other.total_dmfile_scanned_rows(); + total_dmfile_skipped_rows += other.total_dmfile_skipped_rows(); + total_dmfile_rough_set_index_load_time_ms += other.total_dmfile_rough_set_index_load_time_ms(); + total_dmfile_read_time_ms += other.total_dmfile_read_time_ms(); + total_create_snapshot_time_ms += other.total_create_snapshot_time_ms(); + } }; using ScanContextPtr = std::shared_ptr; diff --git a/dbms/src/Storages/MutableSupport.h b/dbms/src/Storages/MutableSupport.h index e3d3f896970..5ff61994dbf 100644 --- a/dbms/src/Storages/MutableSupport.h +++ b/dbms/src/Storages/MutableSupport.h @@ -52,15 +52,6 @@ class MutableSupport : public ext::Singleton block.erase(it); } - bool shouldWiden(const NameAndTypePair & column) - { - DataTypePtr t - = column.type->isNullable() ? dynamic_cast(column.type.get())->getNestedType() : column.type; - return (column.name != MutableSupport::version_column_name && column.name != MutableSupport::delmark_column_name - && column.name != MutableSupport::tidb_pk_column_name) - && t->isInteger() && !(typeid_cast(t.get()) || typeid_cast(t.get())); - } - static const String mmt_storage_name; static const String txn_storage_name; static const String delta_tree_storage_name; diff --git a/dbms/src/Storages/Transaction/Datum.cpp b/dbms/src/Storages/Transaction/Datum.cpp index b379a448db2..d0d43b60d57 100644 --- a/dbms/src/Storages/Transaction/Datum.cpp +++ b/dbms/src/Storages/Transaction/Datum.cpp @@ -79,7 +79,8 @@ struct DatumOp::type> static bool overflow(const Field &, const ColumnInfo &) { return false; } }; -DatumFlat::DatumFlat(const DB::Field & field, TP tp) : DatumBase(field, tp) +DatumFlat::DatumFlat(const DB::Field & field, TP tp) + : DatumBase(field, tp) { if (orig.isNull()) return; @@ -89,7 +90,7 @@ DatumFlat::DatumFlat(const DB::Field & field, TP tp) : DatumBase(field, tp) #ifdef M #error "Please undefine macro M first." #endif -#define M(tt, v, cf, ct, w) \ +#define M(tt, v, cf, ct) \ case Type##tt: \ DatumOp::unflatten(orig, copy); \ break; @@ -98,7 +99,10 @@ DatumFlat::DatumFlat(const DB::Field & field, TP tp) : DatumBase(field, tp) } } -bool DatumFlat::invalidNull(const ColumnInfo & column_info) { return column_info.hasNotNullFlag() && orig.isNull(); } +bool DatumFlat::invalidNull(const ColumnInfo & column_info) +{ + return column_info.hasNotNullFlag() && orig.isNull(); +} bool DatumFlat::overflow(const ColumnInfo & column_info) { @@ -110,8 +114,8 @@ bool DatumFlat::overflow(const ColumnInfo & column_info) #ifdef M #error "Please undefine macro M first." #endif -#define M(tt, v, cf, ct, w) \ - case Type##tt: \ +#define M(tt, v, cf, ct) \ + case Type##tt: \ return DatumOp::overflow(field(), column_info); COLUMN_TYPES(M) #undef M @@ -120,7 +124,8 @@ bool DatumFlat::overflow(const ColumnInfo & column_info) throw DB::Exception("Shouldn't reach here", DB::ErrorCodes::LOGICAL_ERROR); } -DatumBumpy::DatumBumpy(const DB::Field & field, TP tp) : DatumBase(field, tp) +DatumBumpy::DatumBumpy(const DB::Field & field, TP tp) + : DatumBase(field, tp) { if (orig.isNull()) return; @@ -130,7 +135,7 @@ DatumBumpy::DatumBumpy(const DB::Field & field, TP tp) : DatumBase(field, tp) #ifdef M #error "Please undefine macro M first." #endif -#define M(tt, v, cf, ct, w) \ +#define M(tt, v, cf, ct) \ case Type##tt: \ DatumOp::flatten(orig, copy); \ break; diff --git a/dbms/src/Storages/Transaction/LearnerRead.cpp b/dbms/src/Storages/Transaction/LearnerRead.cpp index bc9cde099fb..7a05b8d4508 100644 --- a/dbms/src/Storages/Transaction/LearnerRead.cpp +++ b/dbms/src/Storages/Transaction/LearnerRead.cpp @@ -219,7 +219,7 @@ LearnerReadSnapshot doLearnerRead( const auto & region_to_query = regions_info[region_idx]; const RegionID region_id = region_to_query.region_id; UInt64 physical_tso = read_index_tso >> TsoPhysicalShiftBits; - bool can_stale_read = physical_tso < region_table.getSelfSafeTS(region_id); + bool can_stale_read = mvcc_query_info->read_tso != std::numeric_limits::max() && physical_tso < region_table.getSelfSafeTS(region_id); if (!can_stale_read) { if (auto ori_read_index = mvcc_query_info.getReadIndexRes(region_id); ori_read_index) diff --git a/dbms/src/Storages/Transaction/TiDB.cpp b/dbms/src/Storages/Transaction/TiDB.cpp index 745839a2476..40f189d9009 100644 --- a/dbms/src/Storages/Transaction/TiDB.cpp +++ b/dbms/src/Storages/Transaction/TiDB.cpp @@ -989,8 +989,8 @@ CodecFlag ColumnInfo::getCodecFlag() const #ifdef M #error "Please undefine macro M first." #endif -#define M(tt, v, cf, ct, w) \ - case Type##tt: \ +#define M(tt, v, cf, ct) \ + case Type##tt: \ return getCodecFlagBase(hasUnsignedFlag()); COLUMN_TYPES(M) #undef M diff --git a/dbms/src/Storages/Transaction/TiDB.h b/dbms/src/Storages/Transaction/TiDB.h index cd428e57e6e..c41a2f5157e 100644 --- a/dbms/src/Storages/Transaction/TiDB.h +++ b/dbms/src/Storages/Transaction/TiDB.h @@ -53,46 +53,46 @@ using DB::Timestamp; // Column types. // In format: -// TiDB type, int value, codec flag, CH type, should widen. +// TiDB type, int value, codec flag, CH type. #ifdef M #error "Please undefine macro M first." #endif -#define COLUMN_TYPES(M) \ - M(Decimal, 0, Decimal, Decimal32, false) \ - M(Tiny, 1, VarInt, Int8, true) \ - M(Short, 2, VarInt, Int16, true) \ - M(Long, 3, VarInt, Int32, true) \ - M(Float, 4, Float, Float32, false) \ - M(Double, 5, Float, Float64, false) \ - M(Null, 6, Nil, Nothing, false) \ - M(Timestamp, 7, UInt, MyDateTime, false) \ - M(LongLong, 8, Int, Int64, false) \ - M(Int24, 9, VarInt, Int32, true) \ - M(Date, 10, UInt, MyDate, false) \ - M(Time, 11, Duration, Int64, false) \ - M(Datetime, 12, UInt, MyDateTime, false) \ - M(Year, 13, Int, Int16, false) \ - M(NewDate, 14, Int, MyDate, false) \ - M(Varchar, 15, CompactBytes, String, false) \ - M(Bit, 16, VarInt, UInt64, false) \ - M(JSON, 0xf5, Json, String, false) \ - M(NewDecimal, 0xf6, Decimal, Decimal32, false) \ - M(Enum, 0xf7, VarUInt, Enum16, false) \ - M(Set, 0xf8, VarUInt, UInt64, false) \ - M(TinyBlob, 0xf9, CompactBytes, String, false) \ - M(MediumBlob, 0xfa, CompactBytes, String, false) \ - M(LongBlob, 0xfb, CompactBytes, String, false) \ - M(Blob, 0xfc, CompactBytes, String, false) \ - M(VarString, 0xfd, CompactBytes, String, false) \ - M(String, 0xfe, CompactBytes, String, false) \ - M(Geometry, 0xff, CompactBytes, String, false) +#define COLUMN_TYPES(M) \ + M(Decimal, 0, Decimal, Decimal32) \ + M(Tiny, 1, VarInt, Int8) \ + M(Short, 2, VarInt, Int16) \ + M(Long, 3, VarInt, Int32) \ + M(Float, 4, Float, Float32) \ + M(Double, 5, Float, Float64) \ + M(Null, 6, Nil, Nothing) \ + M(Timestamp, 7, UInt, MyDateTime) \ + M(LongLong, 8, Int, Int64) \ + M(Int24, 9, VarInt, Int32) \ + M(Date, 10, UInt, MyDate) \ + M(Time, 11, Duration, Int64) \ + M(Datetime, 12, UInt, MyDateTime) \ + M(Year, 13, Int, Int16) \ + M(NewDate, 14, Int, MyDate) \ + M(Varchar, 15, CompactBytes, String) \ + M(Bit, 16, VarInt, UInt64) \ + M(JSON, 0xf5, Json, String) \ + M(NewDecimal, 0xf6, Decimal, Decimal32) \ + M(Enum, 0xf7, VarUInt, Enum16) \ + M(Set, 0xf8, VarUInt, UInt64) \ + M(TinyBlob, 0xf9, CompactBytes, String) \ + M(MediumBlob, 0xfa, CompactBytes, String) \ + M(LongBlob, 0xfb, CompactBytes, String) \ + M(Blob, 0xfc, CompactBytes, String) \ + M(VarString, 0xfd, CompactBytes, String) \ + M(String, 0xfe, CompactBytes, String) \ + M(Geometry, 0xff, CompactBytes, String) enum TP { #ifdef M #error "Please undefine macro M first." #endif -#define M(tt, v, cf, ct, w) Type##tt = (v), +#define M(tt, v, cf, ct) Type##tt = (v), COLUMN_TYPES(M) #undef M }; diff --git a/dbms/src/Storages/Transaction/TypeMapping.cpp b/dbms/src/Storages/Transaction/TypeMapping.cpp index 15f11f4f7d6..256855f17b4 100644 --- a/dbms/src/Storages/Transaction/TypeMapping.cpp +++ b/dbms/src/Storages/Transaction/TypeMapping.cpp @@ -104,22 +104,14 @@ struct EnumType : public std::true_type template inline constexpr bool IsEnumType = EnumType::value; -template +template std::enable_if_t && !IsDecimalType && !IsEnumType && !std::is_same_v, DataTypePtr> getDataTypeByColumnInfoBase(const ColumnInfo &, const T *) { - DataTypePtr t = std::make_shared(); - - if (should_widen) - { - auto widen = t->widen(); - t.swap(widen); - } - - return t; + return std::make_shared(); } -template +template std::enable_if_t, DataTypePtr> getDataTypeByColumnInfoBase(const ColumnInfo & column_info, const T *) { DataTypePtr t = nullptr; @@ -129,57 +121,27 @@ std::enable_if_t, DataTypePtr> getDataTypeByColumnInfoBase(const else t = std::make_shared(); - if (should_widen) - { - auto widen = t->widen(); - t.swap(widen); - } - return t; } -template +template std::enable_if_t, DataTypePtr> getDataTypeByColumnInfoBase(const ColumnInfo & column_info, const T *) { - DataTypePtr t = createDecimal(column_info.flen, column_info.decimal); - - if (should_widen) - { - auto widen = t->widen(); - t.swap(widen); - } - - return t; + return createDecimal(column_info.flen, column_info.decimal); } -template +template std::enable_if_t, DataTypePtr> getDataTypeByColumnInfoBase(const ColumnInfo & column_info, const T *) { // In some cases, TiDB will set the decimal to -1, change -1 to 6 to avoid error - DataTypePtr t = std::make_shared(column_info.decimal == -1 ? 6 : column_info.decimal); - - if (should_widen) - { - auto widen = t->widen(); - t.swap(widen); - } - - return t; + return std::make_shared(column_info.decimal == -1 ? 6 : column_info.decimal); } -template +template std::enable_if_t, DataTypePtr> getDataTypeByColumnInfoBase(const ColumnInfo & column_info, const T *) { - DataTypePtr t = std::make_shared(column_info.elems); - - if (should_widen) - { - auto widen = t->widen(); - t.swap(widen); - } - - return t; + return std::make_shared(column_info.elems); } TypeMapping::TypeMapping() @@ -187,17 +149,24 @@ TypeMapping::TypeMapping() #ifdef M #error "Please undefine macro M first." #endif -#define M(tt, v, cf, ct, w) \ - type_map[TiDB::Type##tt] = std::bind(getDataTypeByColumnInfoBase, std::placeholders::_1, (DataType##ct *)nullptr); +#define M(tt, v, cf, ct) \ + type_map[TiDB::Type##tt] = [](const ColumnInfo & column_info) { \ + return getDataTypeByColumnInfoBase(column_info, (DataType##ct *)nullptr); \ + }; COLUMN_TYPES(M) #undef M } +// Get the basic data type according to column_info. +// This method ignores the nullable flag. DataTypePtr TypeMapping::getDataType(const ColumnInfo & column_info) { return type_map[column_info.tp](column_info); } +// Get the data type according to column_info, respecting +// the nullable flag. +// This does not support the "duration" type. DataTypePtr getDataTypeByColumnInfo(const ColumnInfo & column_info) { DataTypePtr base = TypeMapping::instance().getDataType(column_info); @@ -209,6 +178,8 @@ DataTypePtr getDataTypeByColumnInfo(const ColumnInfo & column_info) return base; } +// Get the data type according to column_info. +// This support the duration type that only will be generated when executing DataTypePtr getDataTypeByColumnInfoForComputingLayer(const ColumnInfo & column_info) { DataTypePtr base = TypeMapping::instance().getDataType(column_info); @@ -252,10 +223,10 @@ void setDecimalPrecScale(const T * decimal_type, ColumnInfo & column_info) void fillTiDBColumnInfo(const String & family_name, const ASTPtr & parameters, ColumnInfo & column_info); void fillTiDBColumnInfo(const ASTPtr & type, ColumnInfo & column_info) { - auto * func = typeid_cast(type.get()); + const auto * func = typeid_cast(type.get()); if (func != nullptr) return fillTiDBColumnInfo(func->name, func->arguments, column_info); - auto * ident = typeid_cast(type.get()); + const auto * ident = typeid_cast(type.get()); if (ident != nullptr) return fillTiDBColumnInfo(ident->name, {}, column_info); throw Exception("Failed to get TiDB data type"); @@ -374,7 +345,7 @@ ColumnInfo reverseGetColumnInfo(const NameAndTypePair & column, ColumnID id, con } else { - auto nullable_type = checkAndGetDataType(nested_type); + const auto * nullable_type = checkAndGetDataType(nested_type); nested_type = nullable_type->getNestedType().get(); } @@ -443,28 +414,28 @@ ColumnInfo reverseGetColumnInfo(const NameAndTypePair & column, ColumnID id, con column_info.setUnsignedFlag(); // Fill flen and decimal for decimal. - if (auto decimal_type32 = checkAndGetDataType>(nested_type)) + if (const auto * decimal_type32 = checkAndGetDataType>(nested_type)) setDecimalPrecScale(decimal_type32, column_info); - else if (auto decimal_type64 = checkAndGetDataType>(nested_type)) + else if (const auto * decimal_type64 = checkAndGetDataType>(nested_type)) setDecimalPrecScale(decimal_type64, column_info); - else if (auto decimal_type128 = checkAndGetDataType>(nested_type)) + else if (const auto * decimal_type128 = checkAndGetDataType>(nested_type)) setDecimalPrecScale(decimal_type128, column_info); - else if (auto decimal_type256 = checkAndGetDataType>(nested_type)) + else if (const auto * decimal_type256 = checkAndGetDataType>(nested_type)) setDecimalPrecScale(decimal_type256, column_info); // Fill decimal for date time. - if (auto type = checkAndGetDataType(nested_type)) + if (const auto * type = checkAndGetDataType(nested_type)) column_info.decimal = type->getFraction(); // Fill decimal for duration. - if (auto type = checkAndGetDataType(nested_type)) + if (const auto * type = checkAndGetDataType(nested_type)) column_info.decimal = type->getFsp(); // Fill elems for enum. if (checkDataType(nested_type)) { - auto enum16_type = checkAndGetDataType(nested_type); - for (auto & element : enum16_type->getValues()) + const auto * enum16_type = checkAndGetDataType(nested_type); + for (const auto & element : enum16_type->getValues()) { column_info.elems.emplace_back(element.first, element.second); } diff --git a/dbms/src/Storages/Transaction/tests/gtest_type_mapping.cpp b/dbms/src/Storages/Transaction/tests/gtest_type_mapping.cpp index 4fda5b51f35..1f5dd7ac2f9 100644 --- a/dbms/src/Storages/Transaction/tests/gtest_type_mapping.cpp +++ b/dbms/src/Storages/Transaction/tests/gtest_type_mapping.cpp @@ -22,12 +22,7 @@ namespace DB namespace tests { -TEST(TypeMapping_test, ColumnInfoToDataType) -{ - // TODO fill this test -} - -TEST(TypeMapping_test, DataTypeToColumnInfo) +TEST(TypeMappingTest, DataTypeToColumnInfo) try { String name = "col"; @@ -67,12 +62,19 @@ try { ASSERT_EQ(column_info.tp, TiDB::TypeLongLong) << actual_test_type; } + + auto data_type = getDataTypeByColumnInfo(column_info); + ASSERT_EQ(data_type->getName(), actual_test_type); } } } column_info = reverseGetColumnInfo(NameAndTypePair{name, typeFromString("String")}, 1, default_field, true); ASSERT_EQ(column_info.tp, TiDB::TypeString); + auto data_type = getDataTypeByColumnInfo(column_info); + ASSERT_EQ(data_type->getName(), "String"); + + // TODO: test decimal, datetime, enum } CATCH diff --git a/dbms/src/TestUtils/ColumnsToTiPBExpr.cpp b/dbms/src/TestUtils/ColumnsToTiPBExpr.cpp index af8c8bed4ba..2c993ec91ea 100644 --- a/dbms/src/TestUtils/ColumnsToTiPBExpr.cpp +++ b/dbms/src/TestUtils/ColumnsToTiPBExpr.cpp @@ -13,13 +13,14 @@ // limitations under the License. #include +#include #include #include #include +#include #include #include - namespace DB { namespace tests diff --git a/dbms/src/TestUtils/ExecutorTestUtils.cpp b/dbms/src/TestUtils/ExecutorTestUtils.cpp index 6becfe3d74d..7ca140ee650 100644 --- a/dbms/src/TestUtils/ExecutorTestUtils.cpp +++ b/dbms/src/TestUtils/ExecutorTestUtils.cpp @@ -194,12 +194,19 @@ void ExecutorTest::enablePlanner(bool is_enable) context.context.setSetting("enable_planner", is_enable ? "true" : "false"); } -DB::ColumnsWithTypeAndName ExecutorTest::executeStreams(const std::shared_ptr & request, size_t concurrency) +DB::ColumnsWithTypeAndName ExecutorTest::executeStreams( + const std::shared_ptr & request, + size_t concurrency) { DAGContext dag_context(*request, "executor_test", concurrency); + return executeStreams(&dag_context); +} + +ColumnsWithTypeAndName ExecutorTest::executeStreams(DAGContext * dag_context) +{ context.context.setExecutorTest(); context.context.setMockStorage(context.mockStorage()); - context.context.setDAGContext(&dag_context); + context.context.setDAGContext(dag_context); // Currently, don't care about regions information in tests. Blocks blocks; queryExecute(context.context, /*internal=*/true)->execute([&blocks](const Block & block) { blocks.push_back(block); }).verify(); diff --git a/dbms/src/TestUtils/ExecutorTestUtils.h b/dbms/src/TestUtils/ExecutorTestUtils.h index fe00d14608b..79c279f2822 100644 --- a/dbms/src/TestUtils/ExecutorTestUtils.h +++ b/dbms/src/TestUtils/ExecutorTestUtils.h @@ -92,6 +92,8 @@ class ExecutorTest : public ::testing::Test } } + ColumnsWithTypeAndName executeStreams(DAGContext * dag_context); + ColumnsWithTypeAndName executeStreams( const std::shared_ptr & request, size_t concurrency = 1); diff --git a/dbms/src/TestUtils/MPPTaskTestUtils.cpp b/dbms/src/TestUtils/MPPTaskTestUtils.cpp index 143a9a78034..b187f3e6f5a 100644 --- a/dbms/src/TestUtils/MPPTaskTestUtils.cpp +++ b/dbms/src/TestUtils/MPPTaskTestUtils.cpp @@ -11,7 +11,12 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +#include +#include +#include +#include #include +#include #include namespace DB::tests @@ -80,7 +85,7 @@ std::tuple> MPPTaskTestUtils::prepa return {MPPQueryId(properties.query_ts, properties.local_query_id, properties.server_id, properties.start_ts), res}; } -ColumnsWithTypeAndName MPPTaskTestUtils::exeucteMPPTasks(QueryTasks & tasks, const DAGProperties & properties, std::unordered_map & server_config_map) +ColumnsWithTypeAndName MPPTaskTestUtils::executeMPPTasks(QueryTasks & tasks, const DAGProperties & properties, std::unordered_map & server_config_map) { auto res = executeMPPQueryWithMultipleContext(properties, tasks, server_config_map); return readBlocks(res); @@ -145,8 +150,8 @@ ::testing::AssertionResult MPPTaskTestUtils::assertQueryCancelled(const MPPQuery { std::this_thread::sleep_for(seconds); retry_times++; - // Currenly we wait for 10 times to ensure all tasks are cancelled. - if (retry_times > 10) + // Currenly we wait for 20 times to ensure all tasks are cancelled. + if (retry_times > 20) { return ::testing::AssertionFailure() << "Query not cancelled, " << queryInfo(i) << std::endl; } @@ -166,4 +171,15 @@ ::testing::AssertionResult MPPTaskTestUtils::assertQueryActive(const MPPQueryId } return ::testing::AssertionSuccess(); } + +ColumnsWithTypeAndName MPPTaskTestUtils::buildAndExecuteMPPTasks(DAGRequestBuilder builder) +{ + auto properties = DB::tests::getDAGPropertiesForTest(serverNum()); + for (int i = 0; i < TiFlashTestEnv::globalContextSize(); ++i) + TiFlashTestEnv::getGlobalContext(i).setMPPTest(); + auto tasks = (builder).buildMPPTasks(context, properties); + MockComputeServerManager::instance().resetMockMPPServerInfo(serverNum()); + MockComputeServerManager::instance().setMockStorage(context.mockStorage()); + return executeMPPTasks(tasks, properties, MockComputeServerManager::instance().getServerConfigMap()); +} } // namespace DB::tests diff --git a/dbms/src/TestUtils/MPPTaskTestUtils.h b/dbms/src/TestUtils/MPPTaskTestUtils.h index cb0e84a2a14..75330ed0c6d 100644 --- a/dbms/src/TestUtils/MPPTaskTestUtils.h +++ b/dbms/src/TestUtils/MPPTaskTestUtils.h @@ -15,11 +15,7 @@ #pragma once #include -#include -#include -#include #include -#include namespace DB::tests { @@ -50,12 +46,12 @@ class MockServerAddrGenerator : public ext::Singleton void reset() { - port = 3931; + port = 4931; } private: const Int64 port_upper_bound = 65536; - std::atomic port = 3931; + std::atomic port = 4931; }; // Hold MPP test related infomation: @@ -82,12 +78,14 @@ class MPPTaskTestUtils : public ExecutorTest // run mpp tasks which are ready to cancel, the return value is the start_ts of query. std::tuple> prepareMPPStreams(DAGRequestBuilder builder); - ColumnsWithTypeAndName exeucteMPPTasks(QueryTasks & tasks, const DAGProperties & properties, std::unordered_map & server_config_map); + static ColumnsWithTypeAndName executeMPPTasks(QueryTasks & tasks, const DAGProperties & properties, std::unordered_map & server_config_map); + ColumnsWithTypeAndName buildAndExecuteMPPTasks(DAGRequestBuilder builder); ColumnsWithTypeAndName executeCoprocessorTask(std::shared_ptr & dag_request); static ::testing::AssertionResult assertQueryCancelled(const MPPQueryId & query_id); static ::testing::AssertionResult assertQueryActive(const MPPQueryId & query_id); + static String queryInfo(size_t server_id); protected: @@ -101,7 +99,7 @@ class MPPTaskTestUtils : public ExecutorTest { \ TiFlashTestEnv::getGlobalContext().setMPPTest(); \ MockComputeServerManager::instance().setMockStorage(context.mockStorage()); \ - ASSERT_COLUMNS_EQ_UR(exeucteMPPTasks(tasks, properties, MockComputeServerManager::instance().getServerConfigMap()), expected_cols); \ + ASSERT_COLUMNS_EQ_UR(expected_cols, executeMPPTasks(tasks, properties, MockComputeServerManager::instance().getServerConfigMap())); \ } while (0) diff --git a/dbms/src/TestUtils/mockExecutor.cpp b/dbms/src/TestUtils/mockExecutor.cpp index e6c9a82a231..66fabb59fbd 100644 --- a/dbms/src/TestUtils/mockExecutor.cpp +++ b/dbms/src/TestUtils/mockExecutor.cpp @@ -13,6 +13,17 @@ // limitations under the License. #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include #include @@ -261,22 +272,23 @@ DAGRequestBuilder & DAGRequestBuilder::exchangeSender(tipb::ExchangeType exchang return *this; } -DAGRequestBuilder & DAGRequestBuilder::join(const DAGRequestBuilder & right, - tipb::JoinType tp, - MockAstVec join_cols, - MockAstVec left_conds, - MockAstVec right_conds, - MockAstVec other_conds, - MockAstVec other_eq_conds_from_in) +DAGRequestBuilder & DAGRequestBuilder::join( + const DAGRequestBuilder & right, + tipb::JoinType tp, + MockAstVec join_col_exprs, + MockAstVec left_conds, + MockAstVec right_conds, + MockAstVec other_conds, + MockAstVec other_eq_conds_from_in, + uint64_t fine_grained_shuffle_stream_count) { assert(root); assert(right.root); - - root = mock::compileJoin(getExecutorIndex(), root, right.root, tp, join_cols, left_conds, right_conds, other_conds, other_eq_conds_from_in); + root = mock::compileJoin(getExecutorIndex(), root, right.root, tp, join_col_exprs, left_conds, right_conds, other_conds, other_eq_conds_from_in, fine_grained_shuffle_stream_count); return *this; } -DAGRequestBuilder & DAGRequestBuilder::aggregation(ASTPtr agg_func, ASTPtr group_by_expr) +DAGRequestBuilder & DAGRequestBuilder::aggregation(ASTPtr agg_func, ASTPtr group_by_expr, uint64_t fine_grained_shuffle_stream_count) { auto agg_funcs = std::make_shared(); auto group_by_exprs = std::make_shared(); @@ -284,10 +296,10 @@ DAGRequestBuilder & DAGRequestBuilder::aggregation(ASTPtr agg_func, ASTPtr group agg_funcs->children.push_back(agg_func); if (group_by_expr) group_by_exprs->children.push_back(group_by_expr); - return buildAggregation(agg_funcs, group_by_exprs); + return buildAggregation(agg_funcs, group_by_exprs, fine_grained_shuffle_stream_count); } -DAGRequestBuilder & DAGRequestBuilder::aggregation(MockAstVec agg_funcs, MockAstVec group_by_exprs) +DAGRequestBuilder & DAGRequestBuilder::aggregation(MockAstVec agg_funcs, MockAstVec group_by_exprs, uint64_t fine_grained_shuffle_stream_count) { auto agg_func_list = std::make_shared(); auto group_by_expr_list = std::make_shared(); @@ -295,13 +307,13 @@ DAGRequestBuilder & DAGRequestBuilder::aggregation(MockAstVec agg_funcs, MockAst agg_func_list->children.push_back(func); for (const auto & group_by : group_by_exprs) group_by_expr_list->children.push_back(group_by); - return buildAggregation(agg_func_list, group_by_expr_list); + return buildAggregation(agg_func_list, group_by_expr_list, fine_grained_shuffle_stream_count); } -DAGRequestBuilder & DAGRequestBuilder::buildAggregation(ASTPtr agg_funcs, ASTPtr group_by_exprs) +DAGRequestBuilder & DAGRequestBuilder::buildAggregation(ASTPtr agg_funcs, ASTPtr group_by_exprs, uint64_t fine_grained_shuffle_stream_count) { assert(root); - root = compileAggregation(root, getExecutorIndex(), agg_funcs, group_by_exprs); + root = compileAggregation(root, getExecutorIndex(), agg_funcs, group_by_exprs, fine_grained_shuffle_stream_count); return *this; } @@ -385,6 +397,7 @@ void MockDAGRequestContext::addMockTable(const String & db, const String & table void MockDAGRequestContext::addMockTable(const MockTableName & name, const MockColumnInfoVec & columnInfos, ColumnsWithTypeAndName columns) { + assert(columnInfos.size() == columns.size()); addMockTable(name, columnInfos); addMockTableColumnData(name, columns); } diff --git a/dbms/src/TestUtils/mockExecutor.h b/dbms/src/TestUtils/mockExecutor.h index 11c09caf4cf..8c9b2697ee3 100644 --- a/dbms/src/TestUtils/mockExecutor.h +++ b/dbms/src/TestUtils/mockExecutor.h @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -121,16 +122,17 @@ class DAGRequestBuilder /// @param right_conds conditional expressions which only reference right table and the join type is right kind /// @param other_conds other conditional expressions /// @param other_eq_conds_from_in equality expressions within in subquery whose join type should be AntiSemiJoin, AntiLeftOuterSemiJoin or LeftOuterSemiJoin - DAGRequestBuilder & join(const DAGRequestBuilder & right, tipb::JoinType tp, MockAstVec join_col_exprs, MockAstVec left_conds, MockAstVec right_conds, MockAstVec other_conds, MockAstVec other_eq_conds_from_in); - DAGRequestBuilder & join(const DAGRequestBuilder & right, tipb::JoinType tp, MockAstVec join_col_exprs) + /// @param fine_grained_shuffle_stream_count decide the generated tipb executor's find_grained_shuffle_stream_count + DAGRequestBuilder & join(const DAGRequestBuilder & right, tipb::JoinType tp, MockAstVec join_col_exprs, MockAstVec left_conds, MockAstVec right_conds, MockAstVec other_conds, MockAstVec other_eq_conds_from_in, uint64_t fine_grained_shuffle_stream_count = 0); + DAGRequestBuilder & join(const DAGRequestBuilder & right, tipb::JoinType tp, MockAstVec join_col_exprs, uint64_t fine_grained_shuffle_stream_count = 0) { - return join(right, tp, join_col_exprs, {}, {}, {}, {}); + return join(right, tp, join_col_exprs, {}, {}, {}, {}, fine_grained_shuffle_stream_count); } // aggregation - DAGRequestBuilder & aggregation(ASTPtr agg_func, ASTPtr group_by_expr); - DAGRequestBuilder & aggregation(MockAstVec agg_funcs, MockAstVec group_by_exprs); + DAGRequestBuilder & aggregation(ASTPtr agg_func, ASTPtr group_by_expr, uint64_t fine_grained_shuffle_stream_count = 0); + DAGRequestBuilder & aggregation(MockAstVec agg_funcs, MockAstVec group_by_exprs, uint64_t fine_grained_shuffle_stream_count = 0); // window DAGRequestBuilder & window(ASTPtr window_func, MockOrderByItem order_by, MockPartitionByItem partition_by, MockWindowFrame frame, uint64_t fine_grained_shuffle_stream_count = 0); @@ -144,7 +146,7 @@ class DAGRequestBuilder private: void initDAGRequest(tipb::DAGRequest & dag_request); - DAGRequestBuilder & buildAggregation(ASTPtr agg_funcs, ASTPtr group_by_exprs); + DAGRequestBuilder & buildAggregation(ASTPtr agg_funcs, ASTPtr group_by_exprs, uint64_t fine_grained_shuffle_stream_count = 0); DAGRequestBuilder & buildExchangeReceiver(const MockColumnInfoVec & columns, uint64_t fine_grained_shuffle_stream_count = 0); mock::ExecutorBinderPtr root; diff --git a/docs/design/2022-12-21-auto-reload-tls-certificate.md b/docs/design/2022-12-21-auto-reload-tls-certificate.md new file mode 100644 index 00000000000..e96705cb7ae --- /dev/null +++ b/docs/design/2022-12-21-auto-reload-tls-certificate.md @@ -0,0 +1,75 @@ +# Auto reload TLS certificate for TiFlash + +- Author: [Weiqi Yan](https://github.com/ywqzzy) + +## Introduction + +In TiFlash config, we can set certificate as follows: +```YAML +[security] +ca_path = "/path/to/tls/ca.crt" +cert_path = "/path/to/tls/tiflash.crt" +key_path = "/path/to/tls/tiflash.pem" +``` +Then the TiFlash Server can use TLS certificates to enable secure transmission. + +Since the TLS certificate has a valid period, in order not to affect the normal operation of online business, the TiFlash node should not be restarted manually when replacing the certificate, so it needs to support automatic rotation of TLS certificates. + +By modifying the TiFlash configuration file(tiflash.toml) or the certificate content, TiFlash can dynamically load new TLS certificates. + +## Desgin + +### Overview + +TiFlash uses TLS certificates in GRPC Server, TCP Server, HTTP Server, MetricsPrometheus HTTP Server, client-c (rpcClient, pdClient). + +There are two ways to modify the certificate: + +1. Modify the path of the certificate in the config file. + +2. Directly overwrite the content of the certificate at the specified path. + +For 1, we use `main_config_reloader` to monitor the change of the certificate file path, then update the certificate path information in `Context` after the change of the file path, then update the certificate path maintained in the `TiFlashSecurity` class. + +For 2, the certificate can be dynamically loaded in the form of callback for various servers, and the client connection can be rebuilt for the client. + +### Detailed Design + +#### GRPC Server + +The certificate used by GRPC server should change with the change of the certificate path in `TiFlashSecurity`. +At the same time, each time a new SSL connection is created, a new certificate can be loaded according to the certificate path. + +We need to set one `ConfigFetcher `when building GRPC server in order to dynamically read new certificates when establishing a new SSL connection. + +The `ConfigFetcher` will be set as follows: + +```C++ +builder.AddListeningPort( + raft_config.flash_server_addr, + sslServerCredentialsWithFetcher(context)); +``` + +The `sslServerCredentialsWithFetcher` method will set the `ConfigFetcher` for `grpc::ServerBuilder`. As a callback function, `ConfigFetcher` obtains the certificate path from TiFlashSecurity and sets the certificate for each SSL connection. + +#### HTTP/TCP Server/MetricPrometheus + +These server use `Poco::Net`. To reload certificate dynamically, we can call `SSL_CTX_set_cert_cb` for `Poco::Net::Context::sslContext`. Then these servers can dynamically read new certificates and set new SSL certs when establishing SSL connections with others. + +The setting process is as follows: + +```C++ +SSL_CTX_set_cert_cb(context->sslContext(), + callSetCertificate, + reinterpret_cast(global_context)); +``` + +`callSetCertificate` will read the certificate path from `TiFlashSecurity`, then read the new certificate from certificate path in order to set the new certificate. + +#### client-c + +Judge whether the certificate file has changed (or whether the certificate path has changed) in `main_config_reloader`. When the certificate has changed, read `TiFlashSecurity` to get new certificate paths, and clear the existing client conn array (including `pdClient` and `rpcClient`), so that the new certificate can be read later to create a new connection. + +#### ConfigReloader + +The `ConfigReloader` should monitor whether the certificate file changed or the certificate file paths changed. When changes occur, the `ConfigReloader` should call `reload` to reload some of the configs that need to refresh.