Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

function result name should contain collator info (#2808) #3018

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 55 additions & 21 deletions dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@

namespace DB
{

namespace ErrorCodes
{
extern const int COP_BAD_DAG_REQUEST;
extern const int UNSUPPORTED_METHOD;
} // namespace ErrorCodes

static String genFuncString(const String & func_name, const Names & argument_names)
static String genFuncString(const String & func_name, const Names & argument_names, const TiDB::TiDBCollators & collators)
{
assert(!collators.empty());
std::stringstream ss;
ss << func_name << "(";
bool first = true;
Expand All @@ -46,7 +46,15 @@ static String genFuncString(const String & func_name, const Names & argument_nam
}
ss << argument_name;
}
ss << ") ";
ss << ")_collator";
for (const auto & collator : collators)
{
if (collator == nullptr)
ss << "_0";
else
ss << "_" << collator->getCollatorId();
}
ss << " ";
return ss.str();
}

Expand Down Expand Up @@ -162,10 +170,14 @@ static String buildLogicalFunction(DAGExpressionAnalyzer * analyzer, const tipb:

static const String tidb_cast_name = "tidb_cast";

static String buildCastFunctionInternal(DAGExpressionAnalyzer * analyzer, const Names & argument_names, bool in_union,
const tipb::FieldType & field_type, ExpressionActionsPtr & actions)
static String buildCastFunctionInternal(
DAGExpressionAnalyzer * analyzer,
const Names & argument_names,
bool in_union,
const tipb::FieldType & field_type,
ExpressionActionsPtr & actions)
{
String result_name = genFuncString(tidb_cast_name, argument_names);
String result_name = genFuncString(tidb_cast_name, argument_names, {nullptr});
if (actions->getSampleBlock().has(result_name))
return result_name;

Expand Down Expand Up @@ -200,7 +212,6 @@ static String buildCastFunction(DAGExpressionAnalyzer * analyzer, const tipb::Ex

static String buildDateAddFunction(DAGExpressionAnalyzer * analyzer, const tipb::Expr & expr, ExpressionActionsPtr & actions)
{

static const std::unordered_map<String, String> unit_to_func_name_map({{"DAY", "addDays"}, {"WEEK", "addWeeks"}, {"MONTH", "addMonths"},
{"YEAR", "addYears"}, {"HOUR", "addHours"}, {"MINUTE", "addMinutes"}, {"SECOND", "addSeconds"}});
if (expr.children_size() != 3)
Expand Down Expand Up @@ -303,7 +314,10 @@ static std::unordered_map<String, std::function<String(DAGExpressionAnalyzer *,
});

DAGExpressionAnalyzer::DAGExpressionAnalyzer(std::vector<NameAndTypePair> && source_columns_, const Context & context_)
: source_columns(std::move(source_columns_)), context(context_), after_agg(false), implicit_cast_count(0)
: source_columns(std::move(source_columns_))
, context(context_)
, after_agg(false)
, implicit_cast_count(0)
{
settings = context.getSettings();
}
Expand Down Expand Up @@ -332,7 +346,8 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co
aggregate.argument_names[i] = arg_name;
step.required_output.push_back(arg_name);
}
String func_string = genFuncString(agg_func_name, aggregate.argument_names);
auto function_collator = getCollatorFromExpr(expr);
String func_string = genFuncString(agg_func_name, aggregate.argument_names, {function_collator});
bool duplicate = false;
for (const auto & pre_agg : aggregate_descriptions)
{
Expand All @@ -349,7 +364,7 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co
aggregate.parameters = Array();
/// if there is group by clause, there is no need to consider the empty input case
aggregate.function = AggregateFunctionFactory::instance().get(agg_func_name, types, {}, 0, agg.group_by_size() == 0);
aggregate.function->setCollator(getCollatorFromExpr(expr));
aggregate.function->setCollator(function_collator);
aggregate_descriptions.push_back(aggregate);
DataTypePtr result_type = aggregate.function->getReturnType();
// this is a temporary result, since an implicit cast may be added on these aggregated_columns
Expand Down Expand Up @@ -386,7 +401,8 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co
types[0] = type;
aggregate.argument_names[0] = name;

String func_string = genFuncString(agg_func_name, aggregate.argument_names);
auto function_collator = getCollatorFromExpr(expr);
String func_string = genFuncString(agg_func_name, aggregate.argument_names, {function_collator});
bool duplicate = false;
for (const auto & pre_agg : aggregate_descriptions)
{
Expand All @@ -402,7 +418,7 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co
aggregate.column_name = func_string;
aggregate.parameters = Array();
aggregate.function = AggregateFunctionFactory::instance().get(agg_func_name, types, {}, 0, false);
aggregate.function->setCollator(getCollatorFromExpr(expr));
aggregate.function->setCollator(function_collator);
aggregate_descriptions.push_back(aggregate);
DataTypePtr result_type = aggregate.function->getReturnType();
// this is a temporary result, since an implicit cast may be added on these aggregated_columns
Expand Down Expand Up @@ -431,7 +447,7 @@ bool isUInt8Type(const DataTypePtr & type)
String DAGExpressionAnalyzer::applyFunction(
const String & func_name, const Names & arg_names, ExpressionActionsPtr & actions, std::shared_ptr<TiDB::ITiDBCollator> collator)
{
String result_name = genFuncString(func_name, arg_names);
String result_name = genFuncString(func_name, arg_names, {collator});
if (actions->getSampleBlock().has(result_name))
return result_name;
const FunctionBuilderPtr & function_builder = FunctionFactory::instance().get(func_name, context);
Expand All @@ -441,7 +457,9 @@ String DAGExpressionAnalyzer::applyFunction(
}

void DAGExpressionAnalyzer::appendWhere(
ExpressionActionsChain & chain, const std::vector<const tipb::Expr *> & conditions, String & filter_column_name)
ExpressionActionsChain & chain,
const std::vector<const tipb::Expr *> & conditions,
String & filter_column_name)
{
initChain(chain, getCurrentInputColumns());
ExpressionActionsChain::Step & last_step = chain.steps.back();
Expand Down Expand Up @@ -528,7 +546,9 @@ String DAGExpressionAnalyzer::convertToUInt8(ExpressionActionsPtr & actions, con
}

void DAGExpressionAnalyzer::appendOrderBy(
ExpressionActionsChain & chain, const tipb::TopN & topN, std::vector<NameAndTypePair> & order_columns)
ExpressionActionsChain & chain,
const tipb::TopN & topN,
std::vector<NameAndTypePair> & order_columns)
{
if (topN.order_by_size() == 0)
{
Expand Down Expand Up @@ -568,7 +588,10 @@ void constructTZExpr(tipb::Expr & tz_expr, const TimezoneInfo & dag_timezone_inf
}

String DAGExpressionAnalyzer::appendTimeZoneCast(
const String & tz_col, const String & ts_col, const String & func_name, ExpressionActionsPtr & actions)
const String & tz_col,
const String & ts_col,
const String & func_name,
ExpressionActionsPtr & actions)
{
String cast_expr_name = applyFunction(func_name, {ts_col, tz_col}, actions, nullptr);
return cast_expr_name;
Expand Down Expand Up @@ -614,12 +637,15 @@ bool DAGExpressionAnalyzer::appendTimeZoneCastsAfterTS(ExpressionActionsChain &
}

void DAGExpressionAnalyzer::appendJoin(
ExpressionActionsChain & chain, SubqueryForSet & join_query, const NamesAndTypesList & columns_added_by_join)
ExpressionActionsChain & chain,
SubqueryForSet & join_query,
const NamesAndTypesList & columns_added_by_join)
{
initChain(chain, getCurrentInputColumns());
ExpressionActionsPtr actions = chain.getLastActions();
actions->add(ExpressionAction::ordinaryJoin(join_query.join, columns_added_by_join));
}

/// return true if some actions are needed
bool DAGExpressionAnalyzer::appendJoinKey(ExpressionActionsChain & chain, const google::protobuf::RepeatedPtrField<tipb::Expr> & keys,
const DataTypes & key_types, Names & key_names, bool left, bool is_right_out_join)
Expand Down Expand Up @@ -774,7 +800,10 @@ void DAGExpressionAnalyzer::appendAggSelect(
* @return
*/
String DAGExpressionAnalyzer::alignReturnType(
const tipb::Expr & expr, ExpressionActionsPtr & actions, const String & expr_name, bool force_uint8)
const tipb::Expr & expr,
ExpressionActionsPtr & actions,
const String & expr_name,
bool force_uint8)
{
DataTypePtr orig_type = actions->getSampleBlock().getByName(expr_name).type;
if (force_uint8 && isUInt8Type(orig_type))
Expand All @@ -798,7 +827,10 @@ String DAGExpressionAnalyzer::appendCast(const DataTypePtr & target_type, Expres
}

String DAGExpressionAnalyzer::appendCastIfNeeded(
const tipb::Expr & expr, ExpressionActionsPtr & actions, const String & expr_name, bool explicit_cast)
const tipb::Expr & expr,
ExpressionActionsPtr & actions,
const String & expr_name,
bool explicit_cast)
{
if (!isFunctionExpr(expr))
return expr_name;
Expand All @@ -813,7 +845,6 @@ String DAGExpressionAnalyzer::appendCastIfNeeded(
DataTypePtr actual_type = actions->getSampleBlock().getByName(expr_name).type;
if (expected_type->getName() != actual_type->getName())
{

implicit_cast_count += !explicit_cast;
return appendCast(expected_type, actions, expr_name);
}
Expand All @@ -826,7 +857,10 @@ String DAGExpressionAnalyzer::appendCastIfNeeded(
}

void DAGExpressionAnalyzer::makeExplicitSet(
const tipb::Expr & expr, const Block & sample_block, bool create_ordered_set, const String & left_arg_name)
const tipb::Expr & expr,
const Block & sample_block,
bool create_ordered_set,
const String & left_arg_name)
{
if (prepared_sets.count(&expr))
{
Expand Down
15 changes: 15 additions & 0 deletions tests/tidb-ci/new_collation_fullstack/function_collator.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
mysql> drop table if exists test.t1
mysql> drop table if exists test.t2
mysql> create table test.t1(col_varchar_20_key_signed varchar(20) COLLATE utf8mb4_general_ci, col_varbinary_20_key_signed varbinary(20), col_varbinary_20_undef_signed varbinary(20));
mysql> create table test.t2(col_char_20_key_signed char(20) COLLATE utf8mb4_general_ci, col_varchar_20_undef_signed varchar(20) COLLATE utf8mb4_general_ci);
mysql> alter table test.t1 set tiflash replica 1
mysql> alter table test.t2 set tiflash replica 1
mysql> insert into test.t1 values('Abc',0x62,0x616263);
mysql> insert into test.t2 values('abc','b');
func> wait_table test t1
func> wait_table test t2

mysql> set session tidb_enforce_mpp=1; select * from test.t1 where t1.col_varchar_20_key_signed not in (select col_char_20_key_signed from test.t2 where t1.col_varchar_20_key_signed not in ( t1.col_varbinary_20_key_signed, t1.col_varbinary_20_undef_signed,col_varchar_20_undef_signed,col_char_20_key_signed));
JaySon-Huang marked this conversation as resolved.
Show resolved Hide resolved

mysql> drop table if exists test.t1;
mysql> drop table if exists test.t2;