Skip to content

Commit

Permalink
FLASH-948 fix udaf behavior for empty input (#481)
Browse files Browse the repository at this point in the history
* udafs other than count should return NULL if input is empty

* fix bugs

* fix bug

* add tests

* address comment

* enhance tests

Co-authored-by: ruoxi <[email protected]>
  • Loading branch information
windtalker and zanmato1984 authored Feb 27, 2020
1 parent afcec7f commit 15b2183
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 18 deletions.
7 changes: 5 additions & 2 deletions dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,14 @@ AggregateFunctionPtr AggregateFunctionFactory::get(
const String & name,
const DataTypes & argument_types,
const Array & parameters,
int recursion_level) const
int recursion_level,
bool empty_input_as_null) const
{
/// If one of the types is Nullable, we apply the aggregate function combinator "Null".

if (std::any_of(argument_types.begin(), argument_types.end(),
/// For most aggregate functions other than `count`, an empty input should produce NULL.
/// This flag makes it possible to follow that rule; currently it is only used by Coprocessor queries.
if (empty_input_as_null || std::any_of(argument_types.begin(), argument_types.end(),
[](const auto & type) { return type->isNullable(); }))
{
AggregateFunctionCombinatorPtr combinator = AggregateFunctionCombinatorFactory::instance().tryFindSuffix("Null");
Expand Down
3 changes: 2 additions & 1 deletion dbms/src/AggregateFunctions/AggregateFunctionFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ class AggregateFunctionFactory final : public ext::singleton<AggregateFunctionFa
const String & name,
const DataTypes & argument_types,
const Array & parameters = {},
int recursion_level = 0) const;
int recursion_level = 0,
bool empty_input_as_null = false) const;

/// Returns nullptr if not found.
AggregateFunctionPtr tryGet(
Expand Down
29 changes: 22 additions & 7 deletions dbms/src/AggregateFunctions/AggregateFunctionNull.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,22 @@ class AggregateFunctionCombinatorNull final : public IAggregateFunctionCombinato
}
}

if (!has_nullable_types)
throw Exception("Aggregate function combinator 'Null' requires at least one argument to be Nullable", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

/// Special case for the 'count' function. It can be called with Nullable arguments,
/// which means: count the number of calls in which all arguments are not NULL.
if (nested_function && nested_function->getName() == "count")
{
if (arguments.size() == 1)
return std::make_shared<AggregateFunctionCountNotNullUnary>(arguments[0]);
if(has_nullable_types)
{
if (arguments.size() == 1)
return std::make_shared<AggregateFunctionCountNotNullUnary>(arguments[0]);
else
return std::make_shared<AggregateFunctionCountNotNullVariadic>(arguments);
}
else
return std::make_shared<AggregateFunctionCountNotNullVariadic>(arguments);
{
return std::make_shared<AggregateFunctionCount>();
}
}

if (has_null_types)
Expand All @@ -66,9 +71,19 @@ class AggregateFunctionCombinatorNull final : public IAggregateFunctionCombinato
if (arguments.size() == 1)
{
if (return_type_is_nullable)
return std::make_shared<AggregateFunctionNullUnary<true>>(nested_function);
{
if (has_nullable_types)
return std::make_shared<AggregateFunctionNullUnary<true, true>>(nested_function);
else
return std::make_shared<AggregateFunctionNullUnary<true, false>>(nested_function);
}
else
return std::make_shared<AggregateFunctionNullUnary<false>>(nested_function);
{
if (has_nullable_types)
return std::make_shared<AggregateFunctionNullUnary<false, true>>(nested_function);
else
return std::make_shared<AggregateFunctionNullUnary<false, false>>(nested_function);
}
}
else
{
Expand Down
22 changes: 15 additions & 7 deletions dbms/src/AggregateFunctions/AggregateFunctionNull.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,23 +183,31 @@ class AggregateFunctionNullBase : public IAggregateFunctionHelper<Derived>
/** There are two cases: a single argument and variadic arguments.
* The code for the single-argument case is much more efficient.
*/
template <bool result_is_nullable>
class AggregateFunctionNullUnary final : public AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullUnary<result_is_nullable>>
template <bool result_is_nullable, bool input_is_nullable>
class AggregateFunctionNullUnary final : public AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullUnary<result_is_nullable, input_is_nullable>>
{
public:
AggregateFunctionNullUnary(AggregateFunctionPtr nested_function)
: AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullUnary<result_is_nullable>>(nested_function)
: AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullUnary<result_is_nullable, input_is_nullable>>(nested_function)
{
}

void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
const ColumnNullable * column = static_cast<const ColumnNullable *>(columns[0]);
if (!column->isNullAt(row_num))
if constexpr (input_is_nullable)
{
const ColumnNullable * column = static_cast<const ColumnNullable *>(columns[0]);
if (!column->isNullAt(row_num))
{
this->setFlag(place);
const IColumn * nested_column = &column->getNestedColumn();
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);
}
}
else
{
this->setFlag(place);
const IColumn * nested_column = &column->getNestedColumn();
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);
this->nested_function->add(this->nestedPlace(place), columns, row_num, arena);
}
}
};
Expand Down
2 changes: 1 addition & 1 deletion dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ void DAGExpressionAnalyzer::appendAggregation(
continue;
aggregate.column_name = func_string;
aggregate.parameters = Array();
aggregate.function = AggregateFunctionFactory::instance().get(agg_func_name, types);
aggregate.function = AggregateFunctionFactory::instance().get(agg_func_name, types, {}, 0, true);
aggregate_descriptions.push_back(aggregate);
DataTypePtr result_type = aggregate.function->getReturnType();
// this is a temporary result, since an implicit cast may be added on these aggregated_columns
Expand Down
66 changes: 66 additions & 0 deletions tests/fullstack-test/expr/empty_input_for_udaf.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
mysql> drop table if exists test.t
mysql> create table test.t(a int not null, b int, c int, d int, e int, f int)
mysql> alter table test.t set tiflash replica 1 location labels 'rack', 'host', 'abc'

SLEEP 15

mysql> insert into test.t values (1, 1, 1, 1, 1, 1);
mysql> insert into test.t values (1, 2, 3, NULL, NULL, 1);

SLEEP 15

mysql> select /*+ read_from_storage(tiflash[t]) */ count(1),count(a),count(b),count(d),count(NULL) from test.t where a > 10;
+----------+----------+----------+----------+-------------+
| count(1) | count(a) | count(b) | count(d) | count(NULL) |
+----------+----------+----------+----------+-------------+
| 0 | 0 | 0 | 0 | 0 |
+----------+----------+----------+----------+-------------+

mysql> select /*+ read_from_storage(tiflash[t]) */ count(1),count(a),count(b),count(d),count(NULL) from test.t where a <= 10;
+----------+----------+----------+----------+-------------+
| count(1) | count(a) | count(b) | count(d) | count(NULL) |
+----------+----------+----------+----------+-------------+
| 2 | 2 | 2 | 1 | 0 |
+----------+----------+----------+----------+-------------+

mysql> select /*+ read_from_storage(tiflash[t]) */ sum(1),sum(a),sum(b),sum(d),sum(NULL) from test.t where a > 10;
+--------+--------+--------+--------+-----------+
| sum(1) | sum(a) | sum(b) | sum(d) | sum(NULL) |
+--------+--------+--------+--------+-----------+
| NULL | NULL | NULL | NULL | NULL |
+--------+--------+--------+--------+-----------+

mysql> select /*+ read_from_storage(tiflash[t]) */ sum(1),sum(a),sum(b),sum(d),sum(NULL) from test.t where a <= 10;
+--------+--------+--------+--------+-----------+
| sum(1) | sum(a) | sum(b) | sum(d) | sum(NULL) |
+--------+--------+--------+--------+-----------+
| 2 | 2 | 3 | 1 | NULL |
+--------+--------+--------+--------+-----------+

mysql> select /*+ read_from_storage(tiflash[t]) */ min(1),min(a),min(b),min(d),min(NULL) from test.t where a > 10;
+--------+--------+--------+--------+-----------+
| min(1) | min(a) | min(b) | min(d) | min(NULL) |
+--------+--------+--------+--------+-----------+
| NULL | NULL | NULL | NULL | NULL |
+--------+--------+--------+--------+-----------+

mysql> select /*+ read_from_storage(tiflash[t]) */ min(1),min(a),min(b),min(d),min(NULL) from test.t where a <= 10;
+--------+--------+--------+--------+-----------+
| min(1) | min(a) | min(b) | min(d) | min(NULL) |
+--------+--------+--------+--------+-----------+
| 1 | 1 | 1 | 1 | NULL |
+--------+--------+--------+--------+-----------+

mysql> select /*+ read_from_storage(tiflash[t]) */ max(1),max(a),max(b),max(d),max(NULL) from test.t where a > 10;
+--------+--------+--------+--------+-----------+
| max(1) | max(a) | max(b) | max(d) | max(NULL) |
+--------+--------+--------+--------+-----------+
| NULL | NULL | NULL | NULL | NULL |
+--------+--------+--------+--------+-----------+

mysql> select /*+ read_from_storage(tiflash[t]) */ max(1),max(a),max(b),max(d),max(NULL) from test.t where a <= 10;
+--------+--------+--------+--------+-----------+
| max(1) | max(a) | max(b) | max(d) | max(NULL) |
+--------+--------+--------+--------+-----------+
| 1 | 1 | 2 | 1 | NULL |
+--------+--------+--------+--------+-----------+

0 comments on commit 15b2183

Please sign in to comment.