Skip to content

Commit

Permalink
Issue duckdb#10023: Approx_Count_Distinct Memory Usage
Browse files Browse the repository at this point in the history
Convert approx_count_distinct to use
a 64 byte fixed memory allocation.
This reduces the accuracy a bit
but allows the function state to be paged.

The reported query now runs in 2GB of RAM with paging.

fixes: duckdb#10023
fixes: duckdblabs/duckdb-internal#1081
  • Loading branch information
hawkfish committed Jun 3, 2024
1 parent 362d09e commit 0c5c721
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 68 deletions.
180 changes: 122 additions & 58 deletions src/core_functions/aggregate/distributive/approx_count.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,109 +5,165 @@
#include "duckdb/function/function_set.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"

#include "hyperloglog.hpp"

namespace duckdb {

// Algorithms from
// "New cardinality estimation algorithms for HyperLogLog sketches"
// Otmar Ertl, arXiv:1702.01284
struct ApproxDistinctCountState {
ApproxDistinctCountState() : log(nullptr) {
//! Number of low hash bits that InsertElement uses to pick a register
static constexpr idx_t P = 6;
//! Hash bits left over after the register index is removed (64 - P); InsertElement plants a sentinel at bit Q
static constexpr idx_t Q = 64 - P;
//! Number of registers in the state (2^P); sizes the k[] array
static constexpr idx_t M = 1 << P;
static constexpr double ALPHA = 0.721347520444481703680; // 1 / (2 log(2))

//! Starts with every register at zero, i.e. nothing has been inserted yet.
ApproxDistinctCountState() {
	for (auto &reg : k) {
		reg = 0;
	}
}

//! Position of the lowest set bit of x, found without branching via a 64-bit De Bruijn multiplication.
//! Taken from https://stackoverflow.com/a/72088344
//! For x == 0 the arithmetic below lands on the last table slot and yields 63.
static inline uint8_t CountTrailingZeros(const uint64_t &x) {
	static constexpr const uint64_t MAGIC = 0x03f79d71b4cb0a89;
	static constexpr const uint8_t POSITION[] = {0,  47, 1,  56, 48, 27, 2,  60, 57, 49, 41, 37, 28, 16, 3,  61,
	                                             54, 58, 35, 52, 50, 42, 21, 44, 38, 32, 29, 23, 17, 11, 4,  62,
	                                             46, 55, 26, 59, 40, 36, 15, 53, 34, 51, 20, 43, 31, 22, 10, 45,
	                                             25, 39, 14, 33, 19, 30, 9,  24, 13, 18, 8,  12, 7,  6,  5,  63};
	// x ^ (x - 1) keeps the lowest set bit and turns every bit below it into a one
	const uint64_t low_mask = x ^ (x - 1);
	// the top six bits of the product are unique per mask and select the table slot
	const uint64_t slot = (MAGIC * low_mask) >> 58;
	return POSITION[slot];
}

//! Raises register i to z when z is larger than what the register already holds.
inline void Update(const idx_t &i, const uint8_t &z) {
	auto &reg = k[i];
	if (reg < z) {
		reg = z;
	}
}

//! Algorithm 1
//! The low P bits of the hash choose the register; the trailing-zero count of the rest (plus one) is the rank.
inline void InsertElement(hash_t h) {
	const auto bucket = h & ((1 << P) - 1);
	// drop the bucket bits, then set bit Q as a sentinel so the zero count can never exceed Q
	const hash_t remainder = (h >> P) | (hash_t(1) << Q);
	const uint8_t rank = CountTrailingZeros(remainder) + 1;
	Update(bucket, rank);
}

//! Algorithm 2
inline void Merge(const ApproxDistinctCountState &other) {
for (idx_t i = 0; i < M; ++i) {
Update(i, other.k[i]);
}
}

//! Algorithm 4
//! Adds one to c[v] for every register holding the value v; the caller supplies (and zeroes) the histogram.
void ExtractCounts(uint32_t *c) const {
	for (const auto &reg : k) {
		++c[reg];
	}
}
~ApproxDistinctCountState() {
if (log) {
delete log;

//! Algorithm 6
static int64_t EstimateCardinality(uint32_t *c) {
auto z = M * duckdb_hll::hllTau((M - c[Q]) / double(M));

for (idx_t k = Q; k >= 1; --k) {
z += c[k];
z *= 0.5;
}

z += M * duckdb_hll::hllSigma(c[0] / double(M));

return llroundl(ALPHA * M * M / z);
}

HyperLogLog *log;
//! Builds the register-value histogram for this state and turns it into a cardinality estimate.
idx_t Count() const {
	// one slot per possible register value, 0 through Q + 1
	uint32_t histogram[Q + 2] = {0};
	ExtractCounts(histogram);
	const auto estimate = EstimateCardinality(histogram);
	return idx_t(estimate);
}

uint8_t k[M];
};

//! Static callbacks (initialize / combine / finalize / destroy) for the approx_count_distinct aggregate,
//! templated on the state type.
//! NOTE(review): this listing looks like a rendered diff with the +/- markers stripped. Statements that use
//! `state.log` / `HyperLogLog` and statements that use the register-based state (`Merge`, `Count`) appear
//! side by side - confirm which lines belong to which revision before relying on it.
struct ApproxCountDistinctFunction {
//! Prepares a state; the listing shows both a reset of the `log` pointer and a placement-new of STATE.
template <class STATE>
static void Initialize(STATE &state) {
state.log = nullptr;
new (&state) STATE();
}

//! Folds `source` into `target`.
template <class STATE, class OP>
static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
// pointer-based path: nothing to merge when the source has no log
if (!source.log) {
return;
}
// pointer-based path: allocate the target log on first use
if (!target.log) {
target.log = new HyperLogLog();
}
D_ASSERT(target.log);
D_ASSERT(source.log);
// pointer-based path: MergePointer returns a new log that replaces the old target log
auto new_log = target.log->MergePointer(*source.log);
delete target.log;
target.log = new_log;
// register-based path: merge the source state directly into the target state
target.Merge(source);
}

//! Writes the state's estimate into `target`.
template <class T, class STATE>
static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
// pointer-based path: a state without a log reports 0
if (state.log) {
target = UnsafeNumericCast<T>(state.log->Count());
} else {
target = 0;
}
// register-based path: ask the state itself for the count
target = UnsafeNumericCast<T>(state.Count());
}

//! Always returns true.
static bool IgnoreNull() {
return true;
}

//! Frees the heap-allocated log, if any, and clears the pointer.
template <class STATE>
static void Destroy(STATE &state, AggregateInputData &aggr_input_data) {
if (state.log) {
delete state.log;
state.log = nullptr;
}
}
};

//! Update callback for one aggregate state: `state` is reinterpreted as an ApproxDistinctCountState and the
//! single input column (inputs[0]) is fed into it.
//! NOTE(review): this listing looks like a rendered diff with the +/- markers stripped. Both a
//! HyperLogLog-pointer path (`agg_state->log`, ProcessEntries/AddToLog) and a hash-then-InsertElement path
//! are visible and the braces do not balance as shown - confirm against the real file.
static void ApproxCountDistinctSimpleUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count,
data_ptr_t state, idx_t count) {
D_ASSERT(input_count == 1);
auto &input = inputs[0];
// idata is consulted below for the validity of each input row
UnifiedVectorFormat idata;
input.ToUnifiedFormat(count, idata);

auto agg_state = reinterpret_cast<ApproxDistinctCountState *>(state);
if (!agg_state->log) {
agg_state->log = new HyperLogLog();
// the stack buffer below holds STANDARD_VECTOR_SIZE hashes, so larger counts are rejected first
if (count > STANDARD_VECTOR_SIZE) {
throw InternalException("ApproxCountDistinct - count must be at most vector size");
}
// hash the whole input into a vector that is constructed over the stack array
hash_t hash_array[STANDARD_VECTOR_SIZE];
Vector hash_vec(LogicalType::HASH, reinterpret_cast<data_ptr_t>(&hash_array));
VectorOperations::Hash(input, hash_vec, count);

UnifiedVectorFormat vdata;
inputs[0].ToUnifiedFormat(count, vdata);
UnifiedVectorFormat hdata;
hash_vec.ToUnifiedFormat(count, hdata);
const auto *hashes = UnifiedVectorFormat::GetData<hash_t>(hdata);
auto agg_state = reinterpret_cast<ApproxDistinctCountState *>(state);

if (count > STANDARD_VECTOR_SIZE) {
throw InternalException("ApproxCountDistinct - count must be at most vector size");
// constant hash vector: a single hash is inserted once, and only if row 0 is valid
if (hash_vec.GetVectorType() == VectorType::CONSTANT_VECTOR) {
if (idata.validity.RowIsValid(0)) {
agg_state->InsertElement(hashes[0]);
}
} else {
// otherwise insert the hash of every row whose input value is valid
for (idx_t i = 0; i < count; ++i) {
if (idata.validity.RowIsValid(idata.sel->get_index(i))) {
const auto hash = hashes[hdata.sel->get_index(i)];
agg_state->InsertElement(hash);
}
}
}
// HyperLogLog-pointer path: compute indices/counts for the entries and add them to the state's log
uint64_t indices[STANDARD_VECTOR_SIZE];
uint8_t counts[STANDARD_VECTOR_SIZE];
HyperLogLog::ProcessEntries(vdata, inputs[0].GetType(), indices, counts, count);
agg_state->log->AddToLog(vdata, count, indices, counts);
}

//! Update callback for a vector of aggregate states: row i of the single input column (inputs[0]) is fed
//! into the state pointer found at row i of `state_vector`.
//! NOTE(review): this listing looks like a rendered diff with the +/- markers stripped. `states` and
//! `agg_state` are each declared twice, both a HyperLogLog-pointer path and a hash-then-InsertElement path
//! are visible, and the braces do not balance as shown - confirm against the real file.
static void ApproxCountDistinctUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count,
Vector &state_vector, idx_t count) {
D_ASSERT(input_count == 1);
auto &input = inputs[0];
// idata is consulted below for the validity of each input row
UnifiedVectorFormat idata;
input.ToUnifiedFormat(count, idata);

// the stack buffer below holds STANDARD_VECTOR_SIZE hashes, so larger counts are rejected first
if (count > STANDARD_VECTOR_SIZE) {
throw InternalException("ApproxCountDistinct - count must be at most vector size");
}
// hash the whole input into a vector that is constructed over the stack array
hash_t hash_array[STANDARD_VECTOR_SIZE];
Vector hash_vec(LogicalType::HASH, reinterpret_cast<data_ptr_t>(&hash_array));
VectorOperations::Hash(input, hash_vec, count);

// state_vector holds one ApproxDistinctCountState pointer per row
UnifiedVectorFormat sdata;
state_vector.ToUnifiedFormat(count, sdata);
auto states = UnifiedVectorFormat::GetDataNoConst<ApproxDistinctCountState *>(sdata);
const auto states = UnifiedVectorFormat::GetDataNoConst<ApproxDistinctCountState *>(sdata);

UnifiedVectorFormat hdata;
hash_vec.ToUnifiedFormat(count, hdata);
const auto *hashes = UnifiedVectorFormat::GetData<hash_t>(hdata);
for (idx_t i = 0; i < count; i++) {
auto agg_state = states[sdata.sel->get_index(i)];
if (!agg_state->log) {
agg_state->log = new HyperLogLog();
// rows with an invalid input value are skipped; valid rows insert their hash into their own state
if (idata.validity.RowIsValid(idata.sel->get_index(i))) {
auto agg_state = states[sdata.sel->get_index(i)];
const auto hash = hashes[hdata.sel->get_index(i)];
agg_state->InsertElement(hash);
}
}

UnifiedVectorFormat vdata;
inputs[0].ToUnifiedFormat(count, vdata);

if (count > STANDARD_VECTOR_SIZE) {
throw InternalException("ApproxCountDistinct - count must be at most vector size");
}
// HyperLogLog-pointer path: compute indices/counts for the entries and add them to each row's log
uint64_t indices[STANDARD_VECTOR_SIZE];
uint8_t counts[STANDARD_VECTOR_SIZE];
HyperLogLog::ProcessEntries(vdata, inputs[0].GetType(), indices, counts, count);
HyperLogLog::AddToLogs(vdata, count, indices, counts, reinterpret_cast<HyperLogLog ***>(states), sdata.sel);
}

AggregateFunction GetApproxCountDistinctFunction(const LogicalType &input_type) {
Expand All @@ -117,8 +173,7 @@ AggregateFunction GetApproxCountDistinctFunction(const LogicalType &input_type)
ApproxCountDistinctUpdateFunction,
AggregateFunction::StateCombine<ApproxDistinctCountState, ApproxCountDistinctFunction>,
AggregateFunction::StateFinalize<ApproxDistinctCountState, int64_t, ApproxCountDistinctFunction>,
ApproxCountDistinctSimpleUpdateFunction, nullptr,
AggregateFunction::StateDestroy<ApproxDistinctCountState, ApproxCountDistinctFunction>);
ApproxCountDistinctSimpleUpdateFunction);
fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
return fun;
}
Expand All @@ -130,14 +185,23 @@ AggregateFunctionSet ApproxCountDistinctFun::GetFunctions() {
approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::UINTEGER));
approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::UBIGINT));
approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::UHUGEINT));

approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::TINYINT));
approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::SMALLINT));
approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::INTEGER));
approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::BIGINT));
approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::HUGEINT));

approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::FLOAT));
approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::DOUBLE));

approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::DATE));
approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::TIME));
approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::TIME_TZ));
approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::TIMESTAMP));
approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::TIMESTAMP_TZ));
approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::INTERVAL));

approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::BLOB));
approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::ANY_PARAMS(LogicalType::VARCHAR, 150)));
return approx_count;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ create table t as select range a, mod(range,10) b from range(2000);
query III
SELECT COUNT( a),approx_count_distinct(a),approx_count_distinct(b) from t
----
2000 1991 10
2000 2322 11

query I
SELECT approx_count_distinct(a) from t group by a %2 order by all;
----
986
993
1006
1230

query I
SELECT count(*) from t where a < 10;
Expand All @@ -92,7 +92,7 @@ SELECT approx_count_distinct(a) over (partition by a%2) from t where a < 10;
query II
SELECT COUNT( t),approx_count_distinct(t) from timestamp
----
7 7
7 6

query II
SELECT COUNT( t),approx_count_distinct(t) from dates
Expand Down
8 changes: 4 additions & 4 deletions test/sql/function/list/aggregates/approx_count_distinct.test
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ INSERT INTO timestamp VALUES (['2008-01-01 00:00:01', NULL, '2007-01-01 00:00:01
query II
SELECT list_count(t), list_approx_count_distinct(t) from timestamp
----
7 7
7 6

# strings
statement ok
Expand All @@ -80,7 +80,7 @@ INSERT INTO list_ints_2 SELECT LIST(a), LIST(mod(a, 10)) FROM range(2000) tbl(a)
query III
SELECT list_count(a), list_approx_count_distinct(a), list_approx_count_distinct(b) from list_ints_2
----
2000 1991 10
2000 2322 11

statement ok
DELETE FROM list_ints_2
Expand All @@ -94,5 +94,5 @@ INSERT INTO list_ints_2 SELECT LIST(a), NULL FROM range(2000) tbl(a, b) WHERE a
query I
SELECT list_approx_count_distinct(a) from list_ints_2;
----
993
986
1006
1230
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ from bool;
query II rowsort
select approx_count_distinct(a), approx_count_distinct(b) from t group by b%2;
----
49874 5
51026 5
41234 5
50630 5

query II
select arg_min(b,a), arg_max(b,a) from t;
Expand Down
4 changes: 4 additions & 0 deletions third_party/hyperloglog/hyperloglog.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ robj *hll_merge(robj **hlls, size_t hll_count);
//! Get size (in bytes) of the HLL
uint64_t get_size();

//! Helper Functions
double hllSigma(double x);
double hllTau(double x);

uint64_t MurmurHash64A(const void *key, int len, unsigned int seed);

// NOLINTEND
Expand Down

0 comments on commit 0c5c721

Please sign in to comment.