Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

function result name should contain collator info (#2808) #2818

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 103 additions & 35 deletions dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@

namespace DB
{

namespace ErrorCodes
{
extern const int COP_BAD_DAG_REQUEST;
extern const int UNSUPPORTED_METHOD;
} // namespace ErrorCodes

static String genFuncString(const String & func_name, const Names & argument_names)
static String genFuncString(const String & func_name, const Names & argument_names, const TiDB::TiDBCollators & collators)
{
assert(!collators.empty());
std::stringstream ss;
ss << func_name << "(";
bool first = true;
Expand All @@ -45,7 +45,15 @@ static String genFuncString(const String & func_name, const Names & argument_nam
}
ss << argument_name;
}
ss << ") ";
ss << ")_collator";
for (const auto & collator : collators)
{
if (collator == nullptr)
ss << "_0";
else
ss << "_" << collator->getCollatorId();
}
ss << " ";
return ss.str();
}

Expand Down Expand Up @@ -206,10 +214,14 @@ static String buildLeftUTF8Function(DAGExpressionAnalyzer * analyzer, const tipb

static const String tidb_cast_name = "tidb_cast";

static String buildCastFunctionInternal(DAGExpressionAnalyzer * analyzer, const Names & argument_names, bool in_union,
const tipb::FieldType & field_type, ExpressionActionsPtr & actions)
static String buildCastFunctionInternal(
DAGExpressionAnalyzer * analyzer,
const Names & argument_names,
bool in_union,
const tipb::FieldType & field_type,
ExpressionActionsPtr & actions)
{
String result_name = genFuncString(tidb_cast_name, argument_names);
String result_name = genFuncString(tidb_cast_name, argument_names, {nullptr});
if (actions->getSampleBlock().has(result_name))
return result_name;

Expand Down Expand Up @@ -247,16 +259,29 @@ struct DateAdd
static constexpr auto name = "date_add";
static const std::unordered_map<String, String> unit_to_func_name_map;
};
const std::unordered_map<String, String> DateAdd::unit_to_func_name_map = {{"DAY", "addDays"}, {"WEEK", "addWeeks"}, {"MONTH", "addMonths"},
{"YEAR", "addYears"}, {"HOUR", "addHours"}, {"MINUTE", "addMinutes"}, {"SECOND", "addSeconds"}};
const std::unordered_map<String, String> DateAdd::unit_to_func_name_map
= {
{"DAY", "addDays"},
{"WEEK", "addWeeks"},
{"MONTH", "addMonths"},
{"YEAR", "addYears"},
{"HOUR", "addHours"},
{"MINUTE", "addMinutes"},
{"SECOND", "addSeconds"}};
struct DateSub
{
static constexpr auto name = "date_sub";
static const std::unordered_map<String, String> unit_to_func_name_map;
};
const std::unordered_map<String, String> DateSub::unit_to_func_name_map
= {{"DAY", "subtractDays"}, {"WEEK", "subtractWeeks"}, {"MONTH", "subtractMonths"}, {"YEAR", "subtractYears"},
{"HOUR", "subtractHours"}, {"MINUTE", "subtractMinutes"}, {"SECOND", "subtractSeconds"}};
= {
{"DAY", "subtractDays"},
{"WEEK", "subtractWeeks"},
{"MONTH", "subtractMonths"},
{"YEAR", "subtractYears"},
{"HOUR", "subtractHours"},
{"MINUTE", "subtractMinutes"},
{"SECOND", "subtractSeconds"}};

template <typename Impl>
static String buildDateAddOrSubFunction(DAGExpressionAnalyzer * analyzer, const tipb::Expr & expr, ExpressionActionsPtr & actions)
Expand All @@ -270,12 +295,14 @@ static String buildDateAddOrSubFunction(DAGExpressionAnalyzer * analyzer, const
if (expr.children(2).tp() != tipb::ExprType::String)
{
throw TiFlashException(
std::string() + "3rd argument of " + Impl::name + " function must be string literal", Errors::Coprocessor::BadRequest);
std::string() + "3rd argument of " + Impl::name + " function must be string literal",
Errors::Coprocessor::BadRequest);
}
String unit = expr.children(2).val();
if (Impl::unit_to_func_name_map.find(unit) == Impl::unit_to_func_name_map.end())
throw TiFlashException(
std::string() + Impl::name + " function does not support unit " + unit + " yet.", Errors::Coprocessor::Unimplemented);
std::string() + Impl::name + " function does not support unit " + unit + " yet.",
Errors::Coprocessor::Unimplemented);
String func_name = Impl::unit_to_func_name_map.find(unit)->second;
const auto & date_column_type = removeNullable(actions->getSampleBlock().getByName(date_column).type);
if (!date_column_type->isDateOrDateTime())
Expand Down Expand Up @@ -350,21 +377,32 @@ static std::unordered_map<String, std::function<String(DAGExpressionAnalyzer *,
{"date_add", buildDateAddOrSubFunction<DateAdd>}, {"date_sub", buildDateAddOrSubFunction<DateSub>}});

DAGExpressionAnalyzer::DAGExpressionAnalyzer(std::vector<NameAndTypePair> && source_columns_, const Context & context_)
: source_columns(std::move(source_columns_)), context(context_), after_agg(false), implicit_cast_count(0)
: source_columns(std::move(source_columns_))
, context(context_)
, after_agg(false)
, implicit_cast_count(0)
{
settings = context.getSettings();
}

DAGExpressionAnalyzer::DAGExpressionAnalyzer(std::vector<NameAndTypePair> & source_columns_, const Context & context_)
: source_columns(source_columns_), context(context_), after_agg(false), implicit_cast_count(0)
: source_columns(source_columns_)
, context(context_)
, after_agg(false)
, implicit_cast_count(0)
{
settings = context.getSettings();
}

extern const String CountSecondStage;

void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, const tipb::Aggregation & agg, Names & aggregation_keys,
TiDB::TiDBCollators & collators, AggregateDescriptions & aggregate_descriptions, bool group_by_collation_sensitive)
void DAGExpressionAnalyzer::appendAggregation(
ExpressionActionsChain & chain,
const tipb::Aggregation & agg,
Names & aggregation_keys,
TiDB::TiDBCollators & collators,
AggregateDescriptions & aggregate_descriptions,
bool group_by_collation_sensitive)
{
if (agg.group_by_size() == 0 && agg.agg_func_size() == 0)
{
Expand Down Expand Up @@ -401,7 +439,8 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co
aggregate.argument_names[i] = arg_name;
step.required_output.push_back(arg_name);
}
String func_string = genFuncString(agg_func_name, aggregate.argument_names);
auto function_collator = getCollatorFromExpr(expr);
String func_string = genFuncString(agg_func_name, aggregate.argument_names, {function_collator});
bool duplicate = false;
for (const auto & pre_agg : aggregate_descriptions)
{
Expand All @@ -418,7 +457,7 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co
aggregate.parameters = Array();
/// if there is group by clause, there is no need to consider the empty input case
aggregate.function = AggregateFunctionFactory::instance().get(agg_func_name, types, {}, 0, agg.group_by_size() == 0);
aggregate.function->setCollator(getCollatorFromExpr(expr));
aggregate.function->setCollator(function_collator);
aggregate_descriptions.push_back(aggregate);
DataTypePtr result_type = aggregate.function->getReturnType();
// this is a temp result since implicit cast maybe added on these aggregated_columns
Expand Down Expand Up @@ -463,7 +502,7 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co
types[0] = type;
aggregate.argument_names[0] = name;

String func_string = genFuncString(agg_func_name, aggregate.argument_names);
String func_string = genFuncString(agg_func_name, aggregate.argument_names, {collator});
bool duplicate = false;
for (const auto & pre_agg : aggregate_descriptions)
{
Expand Down Expand Up @@ -508,7 +547,7 @@ bool isUInt8Type(const DataTypePtr & type)
String DAGExpressionAnalyzer::applyFunction(
const String & func_name, const Names & arg_names, ExpressionActionsPtr & actions, std::shared_ptr<TiDB::ITiDBCollator> collator)
{
String result_name = genFuncString(func_name, arg_names);
String result_name = genFuncString(func_name, arg_names, {collator});
if (actions->getSampleBlock().has(result_name))
return result_name;
const FunctionBuilderPtr & function_builder = FunctionFactory::instance().get(func_name, context);
Expand All @@ -518,7 +557,9 @@ String DAGExpressionAnalyzer::applyFunction(
}

void DAGExpressionAnalyzer::appendWhere(
ExpressionActionsChain & chain, const std::vector<const tipb::Expr *> & conditions, String & filter_column_name)
ExpressionActionsChain & chain,
const std::vector<const tipb::Expr *> & conditions,
String & filter_column_name)
{
initChain(chain, getCurrentInputColumns());
ExpressionActionsChain::Step & last_step = chain.steps.back();
Expand Down Expand Up @@ -605,7 +646,9 @@ String DAGExpressionAnalyzer::convertToUInt8(ExpressionActionsPtr & actions, con
}

void DAGExpressionAnalyzer::appendOrderBy(
ExpressionActionsChain & chain, const tipb::TopN & topN, std::vector<NameAndTypePair> & order_columns)
ExpressionActionsChain & chain,
const tipb::TopN & topN,
std::vector<NameAndTypePair> & order_columns)
{
if (topN.order_by_size() == 0)
{
Expand Down Expand Up @@ -645,7 +688,10 @@ void constructTZExpr(tipb::Expr & tz_expr, const TimezoneInfo & dag_timezone_inf
}

String DAGExpressionAnalyzer::appendTimeZoneCast(
const String & tz_col, const String & ts_col, const String & func_name, ExpressionActionsPtr & actions)
const String & tz_col,
const String & ts_col,
const String & func_name,
ExpressionActionsPtr & actions)
{
String cast_expr_name = applyFunction(func_name, {ts_col, tz_col}, actions, nullptr);
return cast_expr_name;
Expand Down Expand Up @@ -688,16 +734,24 @@ bool DAGExpressionAnalyzer::appendTimeZoneCastsAfterTS(ExpressionActionsChain &
}

void DAGExpressionAnalyzer::appendJoin(
ExpressionActionsChain & chain, SubqueryForSet & join_query, const NamesAndTypesList & columns_added_by_join)
ExpressionActionsChain & chain,
SubqueryForSet & join_query,
const NamesAndTypesList & columns_added_by_join)
{
initChain(chain, getCurrentInputColumns());
ExpressionActionsPtr actions = chain.getLastActions();
actions->add(ExpressionAction::ordinaryJoin(join_query.join, columns_added_by_join));
}
/// return true if some actions is needed
bool DAGExpressionAnalyzer::appendJoinKeyAndJoinFilters(ExpressionActionsChain & chain,
const google::protobuf::RepeatedPtrField<tipb::Expr> & keys, const DataTypes & key_types, Names & key_names, bool left,
bool is_right_out_join, const google::protobuf::RepeatedPtrField<tipb::Expr> & filters, String & filter_column_name)
bool DAGExpressionAnalyzer::appendJoinKeyAndJoinFilters(
ExpressionActionsChain & chain,
const google::protobuf::RepeatedPtrField<tipb::Expr> & keys,
const DataTypes & key_types,
Names & key_names,
bool left,
bool is_right_out_join,
const google::protobuf::RepeatedPtrField<tipb::Expr> & filters,
String & filter_column_name)
{
bool ret = false;
initChain(chain, getCurrentInputColumns());
Expand Down Expand Up @@ -849,8 +903,12 @@ void DAGExpressionAnalyzer::appendAggSelect(ExpressionActionsChain & chain, cons
}
}

void DAGExpressionAnalyzer::generateFinalProject(ExpressionActionsChain & chain, const std::vector<tipb::FieldType> & schema,
const std::vector<Int32> & output_offsets, const String & column_prefix, bool keep_session_timezone_info,
void DAGExpressionAnalyzer::generateFinalProject(
ExpressionActionsChain & chain,
const std::vector<tipb::FieldType> & schema,
const std::vector<Int32> & output_offsets,
const String & column_prefix,
bool keep_session_timezone_info,
NamesWithAliases & final_project)
{
if (unlikely(!keep_session_timezone_info && output_offsets.empty()))
Expand Down Expand Up @@ -888,7 +946,8 @@ void DAGExpressionAnalyzer::generateFinalProject(ExpressionActionsChain & chain,
for (auto i : output_offsets)
{
final_project.emplace_back(
current_columns[i].name, unique_name_generator.toUniqueName(column_prefix + current_columns[i].name));
current_columns[i].name,
unique_name_generator.toUniqueName(column_prefix + current_columns[i].name));
}
}
else
Expand Down Expand Up @@ -945,7 +1004,8 @@ void DAGExpressionAnalyzer::generateFinalProject(ExpressionActionsChain & chain,
else
{
final_project.emplace_back(
current_columns[i].name, unique_name_generator.toUniqueName(column_prefix + current_columns[i].name));
current_columns[i].name,
unique_name_generator.toUniqueName(column_prefix + current_columns[i].name));
}
}
}
Expand All @@ -961,7 +1021,10 @@ void DAGExpressionAnalyzer::generateFinalProject(ExpressionActionsChain & chain,
* @return
*/
String DAGExpressionAnalyzer::alignReturnType(
const tipb::Expr & expr, ExpressionActionsPtr & actions, const String & expr_name, bool force_uint8)
const tipb::Expr & expr,
ExpressionActionsPtr & actions,
const String & expr_name,
bool force_uint8)
{
DataTypePtr orig_type = actions->getSampleBlock().getByName(expr_name).type;
if (force_uint8 && isUInt8Type(orig_type))
Expand All @@ -985,7 +1048,10 @@ String DAGExpressionAnalyzer::appendCast(const DataTypePtr & target_type, Expres
}

String DAGExpressionAnalyzer::appendCastIfNeeded(
const tipb::Expr & expr, ExpressionActionsPtr & actions, const String & expr_name, bool explicit_cast)
const tipb::Expr & expr,
ExpressionActionsPtr & actions,
const String & expr_name,
bool explicit_cast)
{
if (!isFunctionExpr(expr))
return expr_name;
Expand All @@ -1000,7 +1066,6 @@ String DAGExpressionAnalyzer::appendCastIfNeeded(
DataTypePtr actual_type = actions->getSampleBlock().getByName(expr_name).type;
if (expected_type->getName() != actual_type->getName())
{

implicit_cast_count += !explicit_cast;
return appendCast(expected_type, actions, expr_name);
}
Expand All @@ -1013,7 +1078,10 @@ String DAGExpressionAnalyzer::appendCastIfNeeded(
}

void DAGExpressionAnalyzer::makeExplicitSet(
const tipb::Expr & expr, const Block & sample_block, bool create_ordered_set, const String & left_arg_name)
const tipb::Expr & expr,
const Block & sample_block,
bool create_ordered_set,
const String & left_arg_name)
{
if (prepared_sets.count(&expr))
{
Expand Down
15 changes: 15 additions & 0 deletions tests/tidb-ci/new_collation_fullstack/function_collator.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
mysql> drop table if exists test.t1
mysql> drop table if exists test.t2
mysql> create table test.t1(col_varchar_20_key_signed varchar(20) COLLATE utf8mb4_general_ci, col_varbinary_20_key_signed varbinary(20), col_varbinary_20_undef_signed varbinary(20));
mysql> create table test.t2(col_char_20_key_signed char(20) COLLATE utf8mb4_general_ci, col_varchar_20_undef_signed varchar(20) COLLATE utf8mb4_general_ci);
mysql> alter table test.t1 set tiflash replica 1
mysql> alter table test.t2 set tiflash replica 1
mysql> insert into test.t1 values('Abc',0x62,0x616263);
mysql> insert into test.t2 values('abc','b');
func> wait_table test t1
func> wait_table test t2

mysql> set session tidb_enforce_mpp=1; select * from test.t1 where t1.col_varchar_20_key_signed not in (select col_char_20_key_signed from test.t2 where t1.col_varchar_20_key_signed not in ( t1.col_varbinary_20_key_signed, t1.col_varbinary_20_undef_signed,col_varchar_20_undef_signed,col_char_20_key_signed));

mysql> drop table if exists test.t1;
mysql> drop table if exists test.t2;