Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
fzhedu committed Aug 30, 2021
1 parent f00e570 commit b5e21f2
Show file tree
Hide file tree
Showing 5 changed files with 276 additions and 268 deletions.
271 changes: 271 additions & 0 deletions dbms/src/AggregateFunctions/AggregateFunctionGroupConcat.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
#include <AggregateFunctions/AggregateFunctionGroupUniqArray.h>
#include <AggregateFunctions/AggregateFunctionNull.h>


namespace DB
{

/// a warp function on the top of groupArray and groupUniqArray, like the AggregateFunctionNull
///
/// the input argument is in following two types:
/// 1. only one column with original data type and without order_by items, for example: group_concat(c)
/// 2. one column combined with more than one columns including concat items and order-by items, it should be like tuple(concat0, concat1... order0, order1 ...), for example:
/// all columns = concat items + order-by items
/// (c0,c1,o0,o1) = group_concat(c0,c1 order by o0,o1)
/// group_concat(distinct c0,c1 order by b0,b1) = groupUniqArray(tuple(c0,c1,b0,b1)) -> distinct (c0, c1) , i.e., remove duplicates further

template <bool result_is_nullable, bool only_one_column>
class AggregateFunctionGroupConcat final : public AggregateFunctionNullBase<result_is_nullable, AggregateFunctionGroupConcat<result_is_nullable, only_one_column>>
{
using State = AggregateFunctionGroupUniqArrayGenericData;

public:
AggregateFunctionGroupConcat(AggregateFunctionPtr nested_function, const DataTypes & input_args, const String& sep, const UInt64& max_len_, const SortDescription & sort_desc_, const NamesAndTypes& all_columns_names_and_types_, const TiDB::TiDBCollators& collators_, const bool has_distinct)
: AggregateFunctionNullBase<result_is_nullable, AggregateFunctionGroupConcat<result_is_nullable, only_one_column>>(nested_function),
separator(sep),max_len(max_len_), sort_desc(sort_desc_),
all_columns_names_and_types(all_columns_names_and_types_), collators(collators_)
{
if (input_args.size() != 1)
throw Exception("Logical error: more than 1 arguments are passed to AggregateFunctionGroupConcat", ErrorCodes::LOGICAL_ERROR);
nested_type = std::make_shared<DataTypeArray>(removeNullable(input_args[0]));

number_of_concat_items = all_columns_names_and_types.size() - sort_desc.size();

is_nullable.resize(number_of_concat_items);
for (size_t i = 0; i < number_of_concat_items; ++i)
{
is_nullable[i] = all_columns_names_and_types[i].type->isNullable();
/// the inputs of a nested agg reject null, but for more than one args, tuple(args...) is already not nullable,
/// so here just remove null for the only_one_column case
if constexpr (only_one_column)
{
all_columns_names_and_types[i].type = removeNullable(all_columns_names_and_types[i].type);
}
}

/// remove redundant rows excluding extra sort items (which do not occur in the concat list) or considering collation
if(has_distinct)
{
for (auto & desc : sort_desc)
{
bool is_extra = true;
for (size_t i = 0; i < number_of_concat_items; ++i)
{
if (desc.column_name == all_columns_names_and_types[i].name)
{
is_extra = false;
break;
}
}
if (is_extra)
{
to_get_unique = true;
break;
}
}
/// because GroupUniqArray does consider collations, so if there are collations,
/// we should additionally remove redundant rows with consideration of collations
if(!to_get_unique)
{
bool has_collation = false;
for (size_t i = 0; i < number_of_concat_items; ++i)
{
if (collators[i] != nullptr)
{
has_collation = true;
break;
}
}
to_get_unique = has_collation;
}
}
}

DataTypePtr getReturnType() const override
{
return result_is_nullable
? makeNullable(ret_type)
: ret_type;
}

/// reject nulls before add() of nested agg
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
if constexpr (only_one_column)
{
if(is_nullable[0])
{
const ColumnNullable * column = static_cast<const ColumnNullable *>(columns[0]);
if (!column->isNullAt(row_num))
{
this->setFlag(place);
const IColumn * nested_column = &column->getNestedColumn();
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);
}
return;
}
}
else
{
/// remove the row with null, except for sort columns
const ColumnTuple & tuple = static_cast<const ColumnTuple &>(*columns[0]);
for (size_t i = 0; i < number_of_concat_items; ++i)
{
if (is_nullable[i])
{
const ColumnNullable & nullable_col = static_cast<const ColumnNullable &>(tuple.getColumn(i));
if (nullable_col.isNullAt(row_num))
{
/// If at least one column has a null value in the current row,
/// we don't process this row.
return;
}
}
}
}
this->setFlag(place);
this->nested_function->add(this->nestedPlace(place), columns, row_num, arena);
}

void insertResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
{
ColumnString * col_str= nullptr;
ColumnNullable * col_null= nullptr;
if constexpr (result_is_nullable)
{
col_null = &static_cast<ColumnNullable &>(to);
col_str = &static_cast<ColumnString &>(col_null->getNestedColumn());
}
else
{
col_str = & static_cast<ColumnString &>(to);
}

if (this->getFlag(place))
{
if constexpr (result_is_nullable)
{
col_null->getNullMapData().push_back(0);
}

/// get results from nested function, named nested_results
auto mutable_nested_cols = nested_type->createColumn();
this->nested_function->insertResultInto(this->nestedPlace(place), *mutable_nested_cols, arena);
const auto column_array = checkAndGetColumn<ColumnArray>(mutable_nested_cols.get());

/// nested_columns are not nullable, because the nullable rows are removed in add()
Columns nested_cols;
if constexpr (only_one_column)
{
nested_cols.push_back(column_array->getDataPtr());
}
else
{
auto & cols = checkAndGetColumn<ColumnTuple>(&column_array->getData())->getColumns();
nested_cols.insert(nested_cols.begin(),cols.begin(),cols.end());
}

/// sort the nested_col of Array type
if(!sort_desc.empty())
sortColumns(nested_cols);

/// get unique flags
std::vector<bool> unique;
if (to_get_unique)
getUnique(nested_cols, unique);

writeToStringColumn(nested_cols,col_str, unique);

}
else
{
if constexpr (result_is_nullable)
col_null->insertDefault();
else
col_str->insertDefault();
}
}

bool allocatesMemoryInArena() const override
{
return this->nested_function->allocatesMemoryInArena();
}

private:
/// construct a block to sort in the case with order-by requirement
void sortColumns(Columns& nested_cols) const
{
Block res;
int concat_size = nested_cols.size();
for(int i = 0 ; i < concat_size; ++i )
{
res.insert(ColumnWithTypeAndName(nested_cols[i], all_columns_names_and_types[i].type, all_columns_names_and_types[i].name));
}
/// sort a block with collation
sortBlock(res, sort_desc);
nested_cols = res.getColumns();
}

/// get unique argument columns by inserting the unique of the first N of (N + M sort) internal columns within tuple
void getUnique(const Columns & cols, std::vector<bool> & unique) const
{
std::unique_ptr<State> state = std::make_unique<State>();
Arena arena1;
auto size = cols[0]->size();
unique.resize(size);
std::vector<String> containers(collators.size());
for (size_t i = 0; i < size; ++i)
{
bool inserted=false;
State::Set::LookupResult it;
const char * begin = nullptr;
size_t values_size = 0;
for (size_t j = 0; j< number_of_concat_items; ++j)
values_size += cols[j]->serializeValueIntoArena(i, arena1, begin, collators[j],containers[j]).size;

StringRef str_serialized= StringRef(begin, values_size);
state->value.emplace(str_serialized, it, inserted);
unique[i] = inserted;
}
}

/// write each column cell to string with separator
void writeToStringColumn(const Columns& cols, ColumnString * const col_str, const std::vector<bool> & unique) const
{
WriteBufferFromOwnString write_buffer;
auto size = cols[0]->size();
for (size_t i = 0; i < size; ++i)
{
if(unique.empty() || unique[i])
{
if (i != 0)
{
writeString(separator, write_buffer);
}
for (size_t j = 0; j < number_of_concat_items; ++j)
{
all_columns_names_and_types[j].type->serializeText(*cols[j], i, write_buffer);
}
}
/// TODO(FZH) output just one warning ("Some rows were cut by GROUPCONCAT()") if this happen
if(write_buffer.count() >=max_len)
{
break;
}
}
col_str->insertData(write_buffer.str().c_str(),std::min(max_len,write_buffer.count()));
}

bool to_get_unique =false;
DataTypePtr ret_type = std::make_shared<DataTypeString>();
DataTypePtr nested_type;
size_t number_of_concat_items = 0;
String separator =",";
UInt64 max_len;
SortDescription sort_desc;
NamesAndTypes all_columns_names_and_types;
TiDB::TiDBCollators collators;
BoolVec is_nullable;
};
}

6 changes: 3 additions & 3 deletions dbms/src/AggregateFunctions/AggregateFunctionNull.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ class AggregateFunctionCombinatorNull final : public IAggregateFunctionCombinato
AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override
{
/// group_concat reuses groupArray and groupUniqArray, it has the special warp function `AggregateFunctionGroupConcat` to process
/// the issues of null, but the warp function needs more complex arguments, it is specially added outside,
/// instead of being added here, so directly return in this function.
/// group_concat reuses groupArray and groupUniqArray with the special warp function `AggregateFunctionGroupConcat` to process,
/// the warp function needs more complex arguments, including collators, sort descriptions and others, which are hard to deliver via Array type,
/// so it is specially added outside, instead of being added here, so directly return in this function.
if (nested_function && (nested_function->getName() == "groupArray" || nested_function->getName() == "groupUniqArray"))
return nested_function;
bool has_nullable_types = false;
Expand Down
Loading

0 comments on commit b5e21f2

Please sign in to comment.