diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp index 78dd014b19a..766fe149add 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp @@ -22,15 +22,15 @@ namespace DB { - namespace ErrorCodes { extern const int COP_BAD_DAG_REQUEST; extern const int UNSUPPORTED_METHOD; } // namespace ErrorCodes -static String genFuncString(const String & func_name, const Names & argument_names) +static String genFuncString(const String & func_name, const Names & argument_names, const TiDB::TiDBCollators & collators) { + assert(!collators.empty()); std::stringstream ss; ss << func_name << "("; bool first = true; @@ -46,7 +46,15 @@ static String genFuncString(const String & func_name, const Names & argument_nam } ss << argument_name; } - ss << ") "; + ss << ")_collator"; + for (const auto & collator : collators) + { + if (collator == nullptr) + ss << "_0"; + else + ss << "_" << collator->getCollatorId(); + } + ss << " "; return ss.str(); } @@ -162,10 +170,14 @@ static String buildLogicalFunction(DAGExpressionAnalyzer * analyzer, const tipb: static const String tidb_cast_name = "tidb_cast"; -static String buildCastFunctionInternal(DAGExpressionAnalyzer * analyzer, const Names & argument_names, bool in_union, - const tipb::FieldType & field_type, ExpressionActionsPtr & actions) +static String buildCastFunctionInternal( + DAGExpressionAnalyzer * analyzer, + const Names & argument_names, + bool in_union, + const tipb::FieldType & field_type, + ExpressionActionsPtr & actions) { - String result_name = genFuncString(tidb_cast_name, argument_names); + String result_name = genFuncString(tidb_cast_name, argument_names, {nullptr}); if (actions->getSampleBlock().has(result_name)) return result_name; @@ -200,7 +212,6 @@ static String buildCastFunction(DAGExpressionAnalyzer * analyzer, const tipb::Ex static String buildDateAddFunction(DAGExpressionAnalyzer * analyzer, const tipb::Expr & expr, ExpressionActionsPtr & actions) { - static const std::unordered_map unit_to_func_name_map({{"DAY", "addDays"}, {"WEEK", "addWeeks"}, {"MONTH", "addMonths"}, {"YEAR", "addYears"}, {"HOUR", "addHours"}, {"MINUTE", "addMinutes"}, {"SECOND", "addSeconds"}}); if (expr.children_size() != 3) @@ -303,7 +314,10 @@ static std::unordered_map && source_columns_, const Context & context_) - : source_columns(std::move(source_columns_)), context(context_), after_agg(false), implicit_cast_count(0) + : source_columns(std::move(source_columns_)) + , context(context_) + , after_agg(false) + , implicit_cast_count(0) { settings = context.getSettings(); } @@ -332,7 +346,8 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co aggregate.argument_names[i] = arg_name; step.required_output.push_back(arg_name); } - String func_string = genFuncString(agg_func_name, aggregate.argument_names); + auto function_collator = getCollatorFromExpr(expr); + String func_string = genFuncString(agg_func_name, aggregate.argument_names, {function_collator}); bool duplicate = false; for (const auto & pre_agg : aggregate_descriptions) { @@ -349,7 +364,7 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co aggregate.parameters = Array(); /// if there is group by clause, there is no need to consider the empty input case aggregate.function = AggregateFunctionFactory::instance().get(agg_func_name, types, {}, 0, agg.group_by_size() == 0); - aggregate.function->setCollator(getCollatorFromExpr(expr)); + aggregate.function->setCollator(function_collator); aggregate_descriptions.push_back(aggregate); DataTypePtr result_type = aggregate.function->getReturnType(); // this is a temp result since implicit cast maybe added on these aggregated_columns @@ -386,7 +401,8 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co types[0] = type; aggregate.argument_names[0] = name; - String func_string = genFuncString(agg_func_name, aggregate.argument_names); + auto function_collator = getCollatorFromExpr(expr); + String func_string = genFuncString(agg_func_name, aggregate.argument_names, {function_collator}); bool duplicate = false; for (const auto & pre_agg : aggregate_descriptions) { @@ -402,7 +418,7 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co aggregate.column_name = func_string; aggregate.parameters = Array(); aggregate.function = AggregateFunctionFactory::instance().get(agg_func_name, types, {}, 0, false); - aggregate.function->setCollator(getCollatorFromExpr(expr)); + aggregate.function->setCollator(function_collator); aggregate_descriptions.push_back(aggregate); DataTypePtr result_type = aggregate.function->getReturnType(); // this is a temp result since implicit cast maybe added on these aggregated_columns @@ -431,7 +447,7 @@ bool isUInt8Type(const DataTypePtr & type) String DAGExpressionAnalyzer::applyFunction( const String & func_name, const Names & arg_names, ExpressionActionsPtr & actions, std::shared_ptr collator) { - String result_name = genFuncString(func_name, arg_names); + String result_name = genFuncString(func_name, arg_names, {collator}); if (actions->getSampleBlock().has(result_name)) return result_name; const FunctionBuilderPtr & function_builder = FunctionFactory::instance().get(func_name, context); @@ -441,7 +457,9 @@ String DAGExpressionAnalyzer::applyFunction( } void DAGExpressionAnalyzer::appendWhere( - ExpressionActionsChain & chain, const std::vector & conditions, String & filter_column_name) + ExpressionActionsChain & chain, + const std::vector & conditions, + String & filter_column_name) { initChain(chain, getCurrentInputColumns()); ExpressionActionsChain::Step & last_step = chain.steps.back(); @@ -528,7 +546,9 @@ String DAGExpressionAnalyzer::convertToUInt8(ExpressionActionsPtr & actions, con } void DAGExpressionAnalyzer::appendOrderBy( - ExpressionActionsChain & chain, const tipb::TopN & topN, std::vector & order_columns) + ExpressionActionsChain & chain, + const tipb::TopN & topN, + std::vector & order_columns) { if (topN.order_by_size() == 0) { @@ -568,7 +588,10 @@ void constructTZExpr(tipb::Expr & tz_expr, const TimezoneInfo & dag_timezone_inf } String DAGExpressionAnalyzer::appendTimeZoneCast( - const String & tz_col, const String & ts_col, const String & func_name, ExpressionActionsPtr & actions) + const String & tz_col, + const String & ts_col, + const String & func_name, + ExpressionActionsPtr & actions) { String cast_expr_name = applyFunction(func_name, {ts_col, tz_col}, actions, nullptr); return cast_expr_name; @@ -614,12 +637,15 @@ bool DAGExpressionAnalyzer::appendTimeZoneCastsAfterTS(ExpressionActionsChain & } void DAGExpressionAnalyzer::appendJoin( - ExpressionActionsChain & chain, SubqueryForSet & join_query, const NamesAndTypesList & columns_added_by_join) + ExpressionActionsChain & chain, + SubqueryForSet & join_query, + const NamesAndTypesList & columns_added_by_join) { initChain(chain, getCurrentInputColumns()); ExpressionActionsPtr actions = chain.getLastActions(); actions->add(ExpressionAction::ordinaryJoin(join_query.join, columns_added_by_join)); } + /// return true if some actions is needed bool DAGExpressionAnalyzer::appendJoinKey(ExpressionActionsChain & chain, const google::protobuf::RepeatedPtrField & keys, const DataTypes & key_types, Names & key_names, bool left, bool is_right_out_join) @@ -774,7 +800,10 @@ void DAGExpressionAnalyzer::appendAggSelect( * @return */ String DAGExpressionAnalyzer::alignReturnType( - const tipb::Expr & expr, ExpressionActionsPtr & actions, const String & expr_name, bool force_uint8) + const tipb::Expr & expr, + ExpressionActionsPtr & actions, + const String & expr_name, + bool force_uint8) { DataTypePtr orig_type = actions->getSampleBlock().getByName(expr_name).type; if (force_uint8 && isUInt8Type(orig_type)) @@ -798,7 +827,10 @@ String DAGExpressionAnalyzer::appendCast(const DataTypePtr & target_type, Expres } String DAGExpressionAnalyzer::appendCastIfNeeded( - const tipb::Expr & expr, ExpressionActionsPtr & actions, const String & expr_name, bool explicit_cast) + const tipb::Expr & expr, + ExpressionActionsPtr & actions, + const String & expr_name, + bool explicit_cast) { if (!isFunctionExpr(expr)) return expr_name; @@ -813,7 +845,6 @@ String DAGExpressionAnalyzer::appendCastIfNeeded( DataTypePtr actual_type = actions->getSampleBlock().getByName(expr_name).type; if (expected_type->getName() != actual_type->getName()) { - implicit_cast_count += !explicit_cast; return appendCast(expected_type, actions, expr_name); } @@ -826,7 +857,10 @@ String DAGExpressionAnalyzer::appendCastIfNeeded( } void DAGExpressionAnalyzer::makeExplicitSet( - const tipb::Expr & expr, const Block & sample_block, bool create_ordered_set, const String & left_arg_name) + const tipb::Expr & expr, + const Block & sample_block, + bool create_ordered_set, + const String & left_arg_name) { if (prepared_sets.count(&expr)) { diff --git a/tests/tidb-ci/new_collation_fullstack/function_collator.test b/tests/tidb-ci/new_collation_fullstack/function_collator.test new file mode 100644 index 00000000000..8cc177738ad --- /dev/null +++ b/tests/tidb-ci/new_collation_fullstack/function_collator.test @@ -0,0 +1,15 @@ +mysql> drop table if exists test.t1 +mysql> drop table if exists test.t2 +mysql> create table test.t1(col_varchar_20_key_signed varchar(20) COLLATE utf8mb4_general_ci, col_varbinary_20_key_signed varbinary(20), col_varbinary_20_undef_signed varbinary(20)); +mysql> create table test.t2(col_char_20_key_signed char(20) COLLATE utf8mb4_general_ci, col_varchar_20_undef_signed varchar(20) COLLATE utf8mb4_general_ci); +mysql> alter table test.t1 set tiflash replica 1 +mysql> alter table test.t2 set tiflash replica 1 +mysql> insert into test.t1 values('Abc',0x62,0x616263); +mysql> insert into test.t2 values('abc','b'); +func> wait_table test t1 +func> wait_table test t2 + +mysql> set @@tidb_isolation_read_engines='tiflash'; select * from test.t1 where t1.col_varchar_20_key_signed not in (select col_char_20_key_signed from test.t2 where t1.col_varchar_20_key_signed not in ( t1.col_varbinary_20_key_signed, t1.col_varbinary_20_undef_signed,col_varchar_20_undef_signed,col_char_20_key_signed)); + +mysql> drop table if exists test.t1; +mysql> drop table if exists test.t2;