Skip to content

Commit

Permalink
Test: support qualified column reference in writing tests (#5614)
Browse files Browse the repository at this point in the history
close #5510
  • Loading branch information
ywqzzy authored Aug 16, 2022
1 parent a89c8b6 commit 49d8050
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 80 deletions.
86 changes: 36 additions & 50 deletions dbms/src/Debug/astToExecutor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ String getFunctionNameForConstantFolding(tipb::Expr * expr)
}
}


void foldConstant(tipb::Expr * expr, int32_t collator_id, const Context & context)
{
if (expr->tp() == tipb::ScalarFunc)
Expand Down Expand Up @@ -286,7 +285,6 @@ void functionToPB(const DAGSchema & input, ASTFunction * func, tipb::Expr * expr

void identifierToPB(const DAGSchema & input, ASTIdentifier * id, tipb::Expr * expr, int32_t collator_id);


void astToPB(const DAGSchema & input, ASTPtr ast, tipb::Expr * expr, int32_t collator_id, const Context & context)
{
if (auto * id = typeid_cast<ASTIdentifier *>(ast.get()))
Expand All @@ -307,17 +305,23 @@ void astToPB(const DAGSchema & input, ASTPtr ast, tipb::Expr * expr, int32_t col
}
}

void functionToPB(const DAGSchema & input, ASTFunction * func, tipb::Expr * expr, int32_t collator_id, const Context & context)
auto checkSchema(const DAGSchema & input, String checked_column)
{
/// aggregation function is handled in Aggregation, so just treated as a column
auto ft = std::find_if(input.begin(), input.end(), [&](const auto & field) {
auto column_name = splitQualifiedName(func->getColumnName());
auto field_name = splitQualifiedName(field.first);
if (column_name.first.empty())
return field_name.second == column_name.second;
auto [checked_db_name, checked_table_name, checked_column_name] = splitQualifiedName(checked_column);
auto [db_name, table_name, column_name] = splitQualifiedName(field.first);
if (checked_table_name.empty())
return column_name == checked_column_name;
else
return field_name.first == column_name.first && field_name.second == column_name.second;
return table_name == checked_table_name && column_name == checked_column_name;
});
return ft;
}

void functionToPB(const DAGSchema & input, ASTFunction * func, tipb::Expr * expr, int32_t collator_id, const Context & context)
{
/// aggregation function is handled in Aggregation, so just treated as a column
auto ft = checkSchema(input, func->getColumnName());
if (ft != input.end())
{
expr->set_tp(tipb::ColumnRef);
Expand Down Expand Up @@ -520,14 +524,7 @@ void functionToPB(const DAGSchema & input, ASTFunction * func, tipb::Expr * expr

void identifierToPB(const DAGSchema & input, ASTIdentifier * id, tipb::Expr * expr, int32_t collator_id)
{
auto ft = std::find_if(input.begin(), input.end(), [&](const auto & field) {
auto column_name = splitQualifiedName(id->getColumnName());
auto field_name = splitQualifiedName(field.first);
if (column_name.first.empty())
return field_name.second == column_name.second;
else
return field_name.first == column_name.first && field_name.second == column_name.second;
});
auto ft = checkSchema(input, id->getColumnName());
if (ft == input.end())
throw Exception("No such column " + id->getColumnName(), ErrorCodes::NO_SUCH_COLUMN_IN_TABLE);
expr->set_tp(tipb::ColumnRef);
Expand All @@ -542,19 +539,18 @@ void collectUsedColumnsFromExpr(const DAGSchema & input, ASTPtr ast, std::unorde
{
if (auto * id = typeid_cast<ASTIdentifier *>(ast.get()))
{
auto column_name = splitQualifiedName(id->getColumnName());
if (!column_name.first.empty())
auto [db_name, table_name, column_name] = splitQualifiedName(id->getColumnName());
if (!table_name.empty())
used_columns.emplace(id->getColumnName());
else
{
bool found = false;
for (const auto & field : input)
{
auto field_name = splitQualifiedName(field.first);
if (field_name.second == column_name.second)
if (splitQualifiedName(field.first).column_name == column_name)
{
if (found)
throw Exception("ambiguous column for " + column_name.second);
throw Exception("ambiguous column for " + column_name);
found = true;
used_columns.emplace(field.first);
}
Expand All @@ -570,14 +566,7 @@ void collectUsedColumnsFromExpr(const DAGSchema & input, ASTPtr ast, std::unorde
else
{
/// check function
auto ft = std::find_if(input.begin(), input.end(), [&](const auto & field) {
auto column_name = splitQualifiedName(func->getColumnName());
auto field_name = splitQualifiedName(field.first);
if (column_name.first.empty())
return field_name.second == column_name.second;
else
return field_name.first == column_name.first && field_name.second == column_name.second;
});
auto ft = checkSchema(input, func->getColumnName());
if (ft != input.end())
{
used_columns.emplace(func->getColumnName());
Expand All @@ -597,14 +586,7 @@ TiDB::ColumnInfo compileExpr(const DAGSchema & input, ASTPtr ast)
if (auto * id = typeid_cast<ASTIdentifier *>(ast.get()))
{
/// check column
auto ft = std::find_if(input.begin(), input.end(), [&](const auto & field) {
auto column_name = splitQualifiedName(id->getColumnName());
auto field_name = splitQualifiedName(field.first);
if (column_name.first.empty())
return field_name.second == column_name.second;
else
return field_name.first == column_name.first && field_name.second == column_name.second;
});
auto ft = checkSchema(input, id->getColumnName());
if (ft == input.end())
throw Exception("No such column " + id->getColumnName(), ErrorCodes::NO_SUCH_COLUMN_IN_TABLE);
ci = ft->second;
Expand Down Expand Up @@ -793,24 +775,28 @@ void setServiceAddr(const std::string & addr)
}
} // namespace Debug

std::pair<String, String> splitQualifiedName(const String & s)
ColumnName splitQualifiedName(const String & s)
{
std::pair<String, String> ret;
ColumnName ret;
Poco::StringTokenizer string_tokens(s, ".");

switch (string_tokens.count())
{
case 1:
ret.second = s;
ret.column_name = s;
break;
case 2:
ret.first = string_tokens[0];
ret.second = string_tokens[1];
ret.table_name = string_tokens[0];
ret.column_name = string_tokens[1];
break;
case 3:
ret.db_name = string_tokens[0];
ret.table_name = string_tokens[1];
ret.column_name = string_tokens[2];
break;
default:
throw Exception("Invalid identifier name " + s);
}

return ret;
}

Expand Down Expand Up @@ -1217,15 +1203,15 @@ void Join::columnPrune(std::unordered_set<String> & used_columns)
auto col_name = identifier->getColumnName();
for (auto & field : children[0]->output_schema)
{
if (col_name == splitQualifiedName(field.first).second)
if (col_name == splitQualifiedName(field.first).column_name)
{
left_used_columns.emplace(field.first);
break;
}
}
for (auto & field : children[1]->output_schema)
{
if (col_name == splitQualifiedName(field.first).second)
if (col_name == splitQualifiedName(field.first).column_name)
{
right_used_columns.emplace(field.first);
break;
Expand Down Expand Up @@ -1272,7 +1258,7 @@ void Join::fillJoinKeyAndFieldType(
{
const auto & [col_name, col_info] = child_schema[index];

if (splitQualifiedName(col_name).second == identifier->getColumnName())
if (splitQualifiedName(col_name).column_name == identifier->getColumnName())
{
auto tipb_type = TiDB::columnInfoToFieldType(col_info);
tipb_type.set_collate(collator_id);
Expand Down Expand Up @@ -1361,7 +1347,7 @@ void Join::toMPPSubPlan(size_t & executor_index, const DAGProperties & propertie
auto push_back_partition_key = [](auto & partition_keys, const auto & child_schema, const auto & key) {
for (size_t index = 0; index < child_schema.size(); ++index)
{
if (splitQualifiedName(child_schema[index].first).second == key->getColumnName())
if (splitQualifiedName(child_schema[index].first).column_name == key->getColumnName())
{
partition_keys.push_back(index);
break;
Expand Down Expand Up @@ -1495,7 +1481,7 @@ bool Sort::toTiPBExecutor(tipb::Executor * tipb_executor, int32_t collator_id, c

} // namespace mock

ExecutorPtr compileTableScan(size_t & executor_index, TableInfo & table_info, String & table_alias, bool append_pk_column)
ExecutorPtr compileTableScan(size_t & executor_index, TableInfo & table_info, const String & db, const String & table_name, bool append_pk_column)
{
DAGSchema ts_output;
for (const auto & column_info : table_info.columns)
Expand All @@ -1510,7 +1496,7 @@ ExecutorPtr compileTableScan(size_t & executor_index, TableInfo & table_info, St
ci.origin_default_value = column_info.origin_default_value;
/// use qualified name as the column name to handle multiple table queries, not very
/// efficient but functionally enough for mock test
ts_output.emplace_back(std::make_pair(table_alias + "." + column_info.name, std::move(ci)));
ts_output.emplace_back(std::make_pair(db + "." + table_name + "." + column_info.name, std::move(ci)));
}
if (append_pk_column)
{
Expand Down
20 changes: 15 additions & 5 deletions dbms/src/Debug/astToExecutor.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,17 @@ extern String LOCAL_HOST;
void setServiceAddr(const std::string & addr);
} // namespace Debug

std::pair<String, String> splitQualifiedName(const String & s);
// We use qualified format like "db_name.table_name.column_name"
// to identify one column of a table.
// We can split the qualified format into the ColumnName struct.
struct ColumnName
{
String db_name;
String table_name;
String column_name;
};

ColumnName splitQualifiedName(const String & s);

struct MPPCtx
{
Expand Down Expand Up @@ -165,11 +175,11 @@ struct TableScan : public Executor

void setTipbColumnInfo(tipb::ColumnInfo * ci, const DAGColumnInfo & dag_column_info) const
{
auto column_name = splitQualifiedName(dag_column_info.first).second;
if (column_name == MutableSupport::tidb_pk_column_name)
auto names = splitQualifiedName(dag_column_info.first);
if (names.column_name == MutableSupport::tidb_pk_column_name)
ci->set_column_id(-1);
else
ci->set_column_id(table_info.getColumnID(column_name));
ci->set_column_id(table_info.getColumnID(names.column_name));
ci->set_tp(dag_column_info.second.tp);
ci->set_flag(dag_column_info.second.flag);
ci->set_columnlen(dag_column_info.second.flen);
Expand Down Expand Up @@ -343,7 +353,7 @@ struct Sort : Executor

using ExecutorPtr = std::shared_ptr<mock::Executor>;

ExecutorPtr compileTableScan(size_t & executor_index, TableInfo & table_info, String & table_alias, bool append_pk_column);
ExecutorPtr compileTableScan(size_t & executor_index, TableInfo & table_info, const String & db, const String & table_name, bool append_pk_column);

ExecutorPtr compileSelection(ExecutorPtr input, size_t & executor_index, ASTPtr filter);

Expand Down
20 changes: 10 additions & 10 deletions dbms/src/Debug/dbgFuncCoprocessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -767,13 +767,13 @@ std::pair<ExecutorPtr, bool> compileQueryBlock(
}
}
}
root_executor = compileTableScan(executor_index, table_info, table_alias, append_pk_column);
root_executor = compileTableScan(executor_index, table_info, "", table_alias, append_pk_column);
}
}
else
{
TableInfo left_table_info = table_info;
String left_table_alias = table_alias;
auto const & left_table_alias = table_alias;
TableInfo right_table_info;
String right_table_alias;
{
Expand Down Expand Up @@ -812,24 +812,24 @@ std::pair<ExecutorPtr, bool> compileQueryBlock(
{
if (auto * identifier = typeid_cast<ASTIdentifier *>(expr.get()))
{
auto names = splitQualifiedName(identifier->getColumnName());
if (names.second == MutableSupport::tidb_pk_column_name)
auto [db_name, table_name, column_name] = splitQualifiedName(identifier->getColumnName());
if (column_name == MutableSupport::tidb_pk_column_name)
{
if (names.first.empty())
if (table_name.empty())
{
throw Exception("tidb pk column must be qualified since there are more than one tables");
}
if (names.first == left_table_alias)
if (table_name == left_table_alias)
left_append_pk_column = true;
else if (names.first == right_table_alias)
else if (table_name == right_table_alias)
right_append_pk_column = true;
else
throw Exception("Unknown table alias: " + names.first);
throw Exception("Unknown table alias: " + table_name);
}
}
}
auto left_ts = compileTableScan(executor_index, left_table_info, left_table_alias, left_append_pk_column);
auto right_ts = compileTableScan(executor_index, right_table_info, right_table_alias, right_append_pk_column);
auto left_ts = compileTableScan(executor_index, left_table_info, "", left_table_alias, left_append_pk_column);
auto right_ts = compileTableScan(executor_index, right_table_info, "", right_table_alias, right_append_pk_column);
root_executor = compileJoin(executor_index, left_ts, right_ts, joined_table->table_join);
}

Expand Down
11 changes: 6 additions & 5 deletions dbms/src/Flash/tests/gtest_compute_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,15 @@ try

size_t task_size = tasks.size();
std::vector<String> expected_strings = {
"exchange_sender_6 | type:Hash, {<0, String>}\n"
" table_scan_1 | {<0, String>}",
"exchange_sender_6 | type:Hash, {<0, String>, <1, String>}\n"
" table_scan_1 | {<0, String>, <1, String>}",
"exchange_sender_5 | type:Hash, {<0, String>, <1, String>}\n"
" table_scan_0 | {<0, String>, <1, String>}",
"exchange_sender_4 | type:PassThrough, {<0, String>, <1, String>, <2, String>}\n"
"exchange_sender_4 | type:PassThrough, {<0, String>, <1, String>, <2, String>, <3, String>}\n"
" topn_3 | order_by: {(<1, String>, desc: false)}, limit: 2\n"
" Join_2 | LeftOuterJoin, HashJoin. left_join_keys: {<0, String>}, right_join_keys: {<0, String>}\n"
" exchange_receiver_7 | type:PassThrough, {<0, String>, <1, String>}\n"
" exchange_receiver_8 | type:PassThrough, {<0, String>}"};
" exchange_receiver_8 | type:PassThrough, {<0, String>, <1, String>}"};
for (size_t i = 0; i < task_size; ++i)
{
ASSERT_DAGREQUEST_EQAUL(expected_strings[i], tasks[i].dag_request);
Expand All @@ -100,7 +100,8 @@ try
auto expected_cols = {
toNullableVec<String>({{}, "banana"}),
toNullableVec<String>({{}, "apple"}),
toNullableVec<String>({{}, {}})};
toNullableVec<String>({{}, "banana"}),
toNullableVec<String>({{}, "apple"})};
ASSERT_MPPTASK_EQUAL(tasks, expected_cols);
}
CATCH
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace DB
{
namespace tests
{
class ExecutorTestRunner : public DB::tests::ExecutorTest
class FilterExecutorTestRunner : public DB::tests::ExecutorTest
{
public:
void initializeContext() override
Expand All @@ -36,7 +36,7 @@ class ExecutorTestRunner : public DB::tests::ExecutorTest
}
};

TEST_F(ExecutorTestRunner, Filter)
TEST_F(FilterExecutorTestRunner, equals)
try
{
WRAP_FOR_DIS_ENABLE_PLANNER_BEGIN
Expand All @@ -62,5 +62,22 @@ try
}
CATCH

TEST_F(FilterExecutorTestRunner, FilterWithQualifiedFormat)
try
{
auto request = context
.scan("test_db", "test_table")
.filter(eq(col("test_table.s1"), col("test_table.s2")))
.build(context);
{
ASSERT_COLUMNS_EQ_R(executeStreams(request),
createColumns({toNullableVec<String>({"banana"}),
toNullableVec<String>({"banana"})}));
}
}
CATCH

/// TODO: more functions.

} // namespace tests
} // namespace DB
Loading

0 comments on commit 49d8050

Please sign in to comment.