Groupby1 #1

Closed · wants to merge 2 commits
cpp/src/arrow/compute/api_aggregate.h (18 additions, 0 deletions)

@@ -127,6 +127,24 @@ struct ARROW_EXPORT QuantileOptions : public FunctionOptions {
enum Interpolation interpolation;
};

// TODO(michalursa) add docstring
struct ARROW_EXPORT GroupByOptions : public FunctionOptions {
struct Aggregate {
/// the name of the aggregation function
std::string function;

/// options for the aggregation function
const FunctionOptions* options;

/// the name of the resulting column in output
std::string name;
};
std::vector<Aggregate> aggregates;

/// the names of key columns
std::vector<std::string> key_names;
};

/// @}

/// \brief Count non-null (or null) values in an array.
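Since the docstring is still a TODO, here is a minimal usage sketch (mine, not part of the PR) of how the fields fit together; the helper name MakeSumByKeyOptions and the column names are illustrative only:

// Hypothetical sketch, not PR code: options describing "sum the first
// argument per group, emit it as column sum_values, key column named key".
#include "arrow/compute/api_aggregate.h"

arrow::compute::GroupByOptions MakeSumByKeyOptions() {
  arrow::compute::GroupByOptions options;
  // GroupedSumImpl ignores its FunctionOptions, so nullptr is acceptable here.
  options.aggregates = {{/*function=*/"sum", /*options=*/nullptr, /*name=*/"sum_values"}};
  // Only a single int64 key column is supported by this PR (see GroupByInit).
  options.key_names = {"key"};
  return options;
}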
cpp/src/arrow/compute/exec.cc (10 additions, 4 deletions)

@@ -847,11 +847,17 @@ class ScalarAggExecutor : public KernelExecutorImpl<ScalarAggregateKernel> {
    KernelContext batch_ctx(exec_context());
    batch_ctx.SetState(batch_state.get());

-   kernel_->consume(&batch_ctx, batch);
-   ARROW_CTX_RETURN_IF_ERROR(&batch_ctx);
+   if (kernel_->nomerge) {
+     kernel_->consume(kernel_ctx_, batch);
+     ARROW_CTX_RETURN_IF_ERROR(kernel_ctx_);
+   } else {
+     kernel_->consume(&batch_ctx, batch);
+     ARROW_CTX_RETURN_IF_ERROR(&batch_ctx);
+
+     kernel_->merge(kernel_ctx_, std::move(*batch_state), state());
+     ARROW_CTX_RETURN_IF_ERROR(kernel_ctx_);
+   }

-   kernel_->merge(kernel_ctx_, std::move(*batch_state), state());
-   ARROW_CTX_RETURN_IF_ERROR(kernel_ctx_);
    return Status::OK();
  }

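To make the control flow above concrete, here is a self-contained toy model (my illustration, not Arrow code) of the two paths: ordinary kernels consume each batch into a fresh per-batch state and merge it into the long-lived state, while nomerge kernels such as group_by (whose MergeFrom is still a TODO below) consume every batch directly into the single long-lived state.

#include <utility>
#include <vector>

// Toy stand-ins for KernelState and the consume/merge callables.
struct ToyState { long total = 0; };

struct ToyKernel {
  bool nomerge = false;
  void Consume(ToyState* state, const std::vector<long>& batch) const {
    for (long v : batch) state->total += v;
  }
  void Merge(ToyState&& src, ToyState* dst) const { dst->total += src.total; }
};

long Execute(const ToyKernel& kernel, const std::vector<std::vector<long>>& batches) {
  ToyState main_state;
  for (const auto& batch : batches) {
    if (kernel.nomerge) {
      kernel.Consume(&main_state, batch);  // accumulate directly, no merge step
    } else {
      ToyState batch_state;                // fresh per-batch state...
      kernel.Consume(&batch_state, batch);
      kernel.Merge(std::move(batch_state), &main_state);  // ...then merge
    }
  }
  return main_state.total;
}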
cpp/src/arrow/compute/kernel.h (6 additions, 4 deletions)

@@ -684,21 +684,23 @@ struct ScalarAggregateKernel : public Kernel {

  ScalarAggregateKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
                        ScalarAggregateConsume consume, ScalarAggregateMerge merge,
-                       ScalarAggregateFinalize finalize)
+                       ScalarAggregateFinalize finalize, bool nomerge = false)
      : Kernel(std::move(sig), init),
        consume(std::move(consume)),
        merge(std::move(merge)),
-       finalize(std::move(finalize)) {}
+       finalize(std::move(finalize)),
+       nomerge(nomerge) {}

  ScalarAggregateKernel(std::vector<InputType> in_types, OutputType out_type,
                        KernelInit init, ScalarAggregateConsume consume,
-                       ScalarAggregateMerge merge, ScalarAggregateFinalize finalize)
+                       ScalarAggregateMerge merge, ScalarAggregateFinalize finalize, bool nomerge = false)
      : ScalarAggregateKernel(KernelSignature::Make(std::move(in_types), out_type), init,
-                             consume, merge, finalize) {}
+                             consume, merge, finalize, nomerge) {}

  ScalarAggregateConsume consume;
  ScalarAggregateMerge merge;
  ScalarAggregateFinalize finalize;
+ bool nomerge;
};

} // namespace compute
cpp/src/arrow/compute/kernels/aggregate_basic.cc (226 additions, 2 deletions)

@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

#include <map>
#include "arrow/compute/api_aggregate.h"
#include "arrow/compute/kernels/aggregate_basic_internal.h"
#include "arrow/compute/kernels/aggregate_internal.h"
@@ -42,9 +43,10 @@ void AggregateFinalize(KernelContext* ctx, Datum* out) {
} // namespace

void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
-                 ScalarAggregateFunction* func, SimdLevel::type simd_level) {
+                 ScalarAggregateFunction* func, SimdLevel::type simd_level,
+                 bool nomerge) {
  ScalarAggregateKernel kernel(std::move(sig), init, AggregateConsume, AggregateMerge,
-                              AggregateFinalize);
+                              AggregateFinalize, nomerge);
// Set the simd level
kernel.simd_level = simd_level;
DCHECK_OK(func->AddKernel(kernel));
@@ -91,6 +93,104 @@ struct CountImpl : public ScalarAggregator {
int64_t nulls = 0;
};

struct GroupedAggregator {
// GroupedAggregator subclasses are expected to be constructible from
// const FunctionOptions*. Will probably need an Init method as well
virtual ~GroupedAggregator() = default;

virtual void Consume(KernelContext*, const ExecBatch& batch,
const uint32_t* group_ids) = 0;

virtual void Finalize(KernelContext* ctx, Datum* out) = 0;

static Result<std::unique_ptr<GroupedAggregator>> Make(std::string function,
const FunctionOptions* options);
};

struct GroupedCountImpl : public GroupedAggregator {
explicit GroupedCountImpl(const FunctionOptions* options)
: options(checked_cast<const CountOptions&>(*options)) {}

void Consume(KernelContext* ctx, const ExecBatch& batch,
const uint32_t* group_ids) override {
if (batch.length == 0) return;

// maybe a batch of group_ids should include the min/max group id
auto max_group = *std::max_element(group_ids, group_ids + batch.length);
if (max_group >= counts.size()) {
counts.resize(max_group + 1, 0);
}

if (options.count_mode == CountOptions::COUNT_NON_NULL) {
auto input = batch[0].make_array();

for (int64_t i = 0; i < input->length(); ++i) {
if (input->IsNull(i)) continue;
counts[group_ids[i]]++;
}
} else {
for (int64_t i = 0; i < batch.length; ++i) {
counts[group_ids[i]]++;
}
}
}

void Finalize(KernelContext* ctx, Datum* out) override {
KERNEL_ASSIGN_OR_RAISE(auto counts_buf, ctx,
ctx->Allocate(sizeof(int64_t) * counts.size()));
std::copy(counts.begin(), counts.end(),
reinterpret_cast<int64_t*>(counts_buf->mutable_data()));
*out = std::make_shared<Int64Array>(counts.size(), std::move(counts_buf));
}

CountOptions options;
std::vector<int64_t> counts;
};

struct GroupedSumImpl : public GroupedAggregator {
explicit GroupedSumImpl(const FunctionOptions*) {}

void Consume(KernelContext* ctx, const ExecBatch& batch,
const uint32_t* group_ids) override {
if (batch.length == 0) return;

// maybe a batch of group_ids should include the min/max group id
auto max_group = *std::max_element(group_ids, group_ids + batch.length);
if (max_group >= sums.size()) {
sums.resize(max_group + 1, 0.0);
}

DCHECK_EQ(batch[0].type()->id(), Type::DOUBLE);
auto input = batch[0].array_as<DoubleArray>();

for (int64_t i = 0; i < input->length(); ++i) {
if (input->IsNull(i)) continue;
sums[group_ids[i]] += input->Value(i);
}
}

void Finalize(KernelContext* ctx, Datum* out) override {
KERNEL_ASSIGN_OR_RAISE(auto sums_buf, ctx,
ctx->Allocate(sizeof(double) * sums.size()));
std::copy(sums.begin(), sums.end(),
reinterpret_cast<double*>(sums_buf->mutable_data()));
*out = std::make_shared<DoubleArray>(sums.size(), std::move(sums_buf));
}

std::vector<double> sums;
};

Result<std::unique_ptr<GroupedAggregator>> GroupedAggregator::Make(
std::string function, const FunctionOptions* options) {
if (function == "count") {
return ::arrow::internal::make_unique<GroupedCountImpl>(options);
}
if (function == "sum") {
return ::arrow::internal::make_unique<GroupedSumImpl>(options);
}
return Status::NotImplemented("Grouped aggregate ", function);
}
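Adding another aggregation means implementing the same three-method contract and extending Make with one more branch. A hypothetical sketch (not in this PR, and it would also need <algorithm> and <limits>), mirroring GroupedSumImpl's float64-only shortcut:

// Hypothetical GroupedMinImpl, written to the same contract as the
// implementations above; illustrative only.
struct GroupedMinImpl : public GroupedAggregator {
  explicit GroupedMinImpl(const FunctionOptions*) {}

  void Consume(KernelContext* ctx, const ExecBatch& batch,
               const uint32_t* group_ids) override {
    if (batch.length == 0) return;

    auto max_group = *std::max_element(group_ids, group_ids + batch.length);
    if (max_group >= mins.size()) {
      // New groups start at +infinity so any real value replaces them.
      mins.resize(max_group + 1, std::numeric_limits<double>::infinity());
    }

    DCHECK_EQ(batch[0].type()->id(), Type::DOUBLE);
    auto input = batch[0].array_as<DoubleArray>();

    for (int64_t i = 0; i < input->length(); ++i) {
      if (input->IsNull(i)) continue;
      mins[group_ids[i]] = std::min(mins[group_ids[i]], input->Value(i));
    }
  }

  void Finalize(KernelContext* ctx, Datum* out) override {
    KERNEL_ASSIGN_OR_RAISE(auto mins_buf, ctx,
                           ctx->Allocate(sizeof(double) * mins.size()));
    std::copy(mins.begin(), mins.end(),
              reinterpret_cast<double*>(mins_buf->mutable_data()));
    *out = std::make_shared<DoubleArray>(mins.size(), std::move(mins_buf));
  }

  std::vector<double> mins;
};

Make would then gain an `if (function == "min")` branch alongside the two above.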

std::unique_ptr<KernelState> CountInit(KernelContext*, const KernelInitArgs& args) {
return ::arrow::internal::make_unique<CountImpl>(
static_cast<const CountOptions&>(*args.options));
@@ -229,6 +329,114 @@ std::unique_ptr<KernelState> AllInit(KernelContext*, const KernelInitArgs& args)
return ::arrow::internal::make_unique<BooleanAllImpl>();
}

struct GroupByImpl : public ScalarAggregator {
void Consume(KernelContext* ctx, const ExecBatch& batch) override {
ArrayDataVector aggregands, keys;

size_t i;
for (i = 0; i < aggregators.size(); ++i) {
aggregands.push_back(batch[i].array());
}
while (i < static_cast<size_t>(batch.num_values())) {
keys.push_back(batch[i++].array());
}

auto key64 = batch[aggregators.size()].array_as<Int64Array>();
if (key64->null_count() != 0) {
ctx->SetStatus(Status::NotImplemented("nulls in key column"));
return;
}

const int64_t* key64_raw = key64->raw_values();

std::vector<uint32_t> group_ids(batch.length);
for (int64_t i = 0; i < batch.length; ++i) {
uint64_t key = key64_raw[i];
auto iter = map_.find(key);
if (iter == map_.end()) {
group_ids[i] = static_cast<uint32_t>(keys_.size());
keys_.push_back(key);
map_.insert(std::make_pair(key, group_ids[i]));
} else {
group_ids[i] = iter->second;
}
}

for (size_t i = 0; i < aggregators.size(); ++i) {
ExecBatch aggregand_batch{{aggregands[i]}, batch.length};
aggregators[i]->Consume(ctx, aggregand_batch, group_ids.data());
if (ctx->HasError()) return;
}
}

void MergeFrom(KernelContext* ctx, KernelState&& src) override {
// TODO(michalursa) merge two hash tables
}

void Finalize(KernelContext* ctx, Datum* out) override {
FieldVector out_fields(aggregators.size() + 1);
ArrayDataVector out_columns(aggregators.size() + 1);
for (size_t i = 0; i < aggregators.size(); ++i) {
Datum aggregand;
aggregators[i]->Finalize(ctx, &aggregand);
if (ctx->HasError()) return;
out_columns[i] = aggregand.array();
out_fields[i] = field(options.aggregates[i].name, aggregand.type());
}

int64_t length = keys_.size();
KERNEL_ASSIGN_OR_RAISE(auto key_buf, ctx, ctx->Allocate(sizeof(int64_t) * length));
std::copy(keys_.begin(), keys_.end(),
reinterpret_cast<int64_t*>(key_buf->mutable_data()));
auto key = std::make_shared<Int64Array>(length, std::move(key_buf));

out_columns.back() = key->data();
out_fields.back() = field(options.key_names[0], key->type());

*out = ArrayData::Make(struct_(std::move(out_fields)), key->length(),
{/*null_bitmap=*/nullptr}, std::move(out_columns));
}

std::map<uint64_t, uint32_t> map_;
std::vector<uint64_t> keys_;

GroupByOptions options;
std::vector<std::unique_ptr<GroupedAggregator>> aggregators;
};
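Concretely: given an int64 key column [5, 7, 5, 9], Consume assigns group_ids [0, 1, 0, 2], keys_ becomes [5, 7, 9], and Finalize emits one row per distinct key, with the per-group aggregate columns followed by the key column.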

std::unique_ptr<KernelState> GroupByInit(KernelContext* ctx, const KernelInitArgs& args) {
// TODO(michalursa) do construction of group by implementation
auto impl = ::arrow::internal::make_unique<GroupByImpl>();
impl->options = *checked_cast<const GroupByOptions*>(args.options);
const auto& aggregates = impl->options.aggregates;

if (aggregates.size() > args.inputs.size()) {
ctx->SetStatus(Status::Invalid("more aggegates than inputs!"));
return nullptr;
}

impl->aggregators.resize(aggregates.size());
for (size_t i = 0; i < aggregates.size(); ++i) {
ctx->SetStatus(GroupedAggregator::Make(aggregates[i].function, aggregates[i].options)
.Value(&impl->aggregators[i]));
if (ctx->HasError()) return nullptr;
}

size_t n_keys = args.inputs.size() - aggregates.size();
if (n_keys != 1) {
ctx->SetStatus(Status::NotImplemented("more than one key"));
return nullptr;
}

if (args.inputs.back().type->id() != Type::INT64) {
ctx->SetStatus(
Status::NotImplemented("key of type", args.inputs.back().type->ToString()));
return nullptr;
}

return impl;
}

void AddBasicAggKernels(KernelInit init,
const std::vector<std::shared_ptr<DataType>>& types,
std::shared_ptr<DataType> out_ty, ScalarAggregateFunction* func,
@@ -286,6 +494,9 @@ const FunctionDoc all_doc{
("Null values are ignored."),
{"array"}};

// TODO(michalursa) add FunctionDoc for group_by
const FunctionDoc group_by_doc{"", (""), {}};

} // namespace

void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
@@ -368,6 +579,19 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
func = std::make_shared<ScalarAggregateFunction>("all", Arity::Unary(), &all_doc);
aggregate::AddBasicAggKernels(aggregate::AllInit, {boolean()}, boolean(), func.get());
DCHECK_OK(registry->AddFunction(std::move(func)));

// group_by
func = std::make_shared<ScalarAggregateFunction>("group_by", Arity::VarArgs(),
&group_by_doc);
// aggregate::AddBasicAggKernels(aggregate::GroupByInit, {null()}, null(), func.get());
{
InputType any_array(ValueDescr::ARRAY);
auto sig = KernelSignature::Make({any_array}, ValueDescr::Array(int64()), true);
AddAggKernel(std::move(sig), aggregate::GroupByInit, func.get(), SimdLevel::NONE,
true);
}
DCHECK_OK(registry->AddFunction(std::move(func)));
// TODO(michalursa) add Kernels to the function named "group_by"
}

} // namespace internal
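For reference, a hedged sketch (mine, not PR code) of invoking the function registered above and unpacking its struct-typed result; the helper name is hypothetical, and column order follows GroupByImpl::Finalize:

// Hypothetical call site; assumes options built as in the earlier sketch,
// float64 values, and an int64 key column.
#include "arrow/array.h"
#include "arrow/compute/api.h"

arrow::Status UnpackGroupByResult(const std::shared_ptr<arrow::Array>& values,
                                  const std::shared_ptr<arrow::Array>& keys,
                                  const arrow::compute::GroupByOptions& options) {
  // Argument order per GroupByImpl::Consume: aggregands first, then the key.
  ARROW_ASSIGN_OR_RAISE(arrow::Datum out, arrow::compute::CallFunction(
                                              "group_by", {values, keys}, &options));
  // Finalize produces a struct array: one child per aggregate, then the keys.
  auto result = std::static_pointer_cast<arrow::StructArray>(out.make_array());
  std::shared_ptr<arrow::Array> per_group = result->field(0);     // aggregate column
  std::shared_ptr<arrow::Array> distinct_keys = result->field(1); // key column
  return arrow::Status::OK();
}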
cpp/src/arrow/compute/kernels/aggregate_internal.h (1 addition, 1 deletion)

@@ -55,7 +55,7 @@ struct ScalarAggregator : public KernelState {

void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
                  ScalarAggregateFunction* func,
-                 SimdLevel::type simd_level = SimdLevel::NONE);
+                 SimdLevel::type simd_level = SimdLevel::NONE, bool nomerge = false);

} // namespace compute
} // namespace arrow