Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-11929: [C++][Dataset][Compute] Promote expression to the compute namespace #10166

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions cpp/examples/arrow/dataset_documentation_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@

#include <arrow/api.h>
#include <arrow/compute/cast.h>
#include <arrow/compute/exec/expression.h>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll follow up and adjust the line numbers in the corresponding reST file (and see if I can figure out a better way to excerpt code snippets than hardcoding line numbers).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#include <arrow/dataset/dataset.h>
#include <arrow/dataset/discovery.h>
#include <arrow/dataset/expression.h>
#include <arrow/dataset/file_base.h>
#include <arrow/dataset/file_ipc.h>
#include <arrow/dataset/file_parquet.h>
Expand All @@ -37,6 +37,7 @@

namespace ds = arrow::dataset;
namespace fs = arrow::fs;
namespace cp = arrow::compute;

#define ABORT_ON_FAILURE(expr) \
do { \
Expand Down Expand Up @@ -185,7 +186,7 @@ std::shared_ptr<arrow::Table> FilterAndSelectDataset(
// Read specified columns with a row filter
auto scan_builder = dataset->NewScan().ValueOrDie();
ABORT_ON_FAILURE(scan_builder->Project({"b"}));
ABORT_ON_FAILURE(scan_builder->Filter(ds::less(ds::field_ref("b"), ds::literal(4))));
ABORT_ON_FAILURE(scan_builder->Filter(cp::less(cp::field_ref("b"), cp::literal(4))));
auto scanner = scan_builder->Finish().ValueOrDie();
return scanner->ToTable().ValueOrDie();
}
Expand All @@ -210,12 +211,12 @@ std::shared_ptr<arrow::Table> ProjectDataset(
ABORT_ON_FAILURE(scan_builder->Project(
{
// Leave column "a" as-is.
ds::field_ref("a"),
cp::field_ref("a"),
// Cast column "b" to float32.
ds::call("cast", {ds::field_ref("b")},
cp::call("cast", {cp::field_ref("b")},
arrow::compute::CastOptions::Safe(arrow::float32())),
// Derive a boolean column from "c".
ds::equal(ds::field_ref("c"), ds::literal(1)),
cp::equal(cp::field_ref("c"), cp::literal(1)),
},
{"a_renamed", "b_as_float32", "c_1"}));
auto scanner = scan_builder->Finish().ValueOrDie();
Expand All @@ -239,15 +240,15 @@ std::shared_ptr<arrow::Table> SelectAndProjectDataset(
// Read specified columns with a row filter
auto scan_builder = dataset->NewScan().ValueOrDie();
std::vector<std::string> names;
std::vector<ds::Expression> exprs;
std::vector<cp::Expression> exprs;
// Read all the original columns.
for (const auto& field : dataset->schema()->fields()) {
names.push_back(field->name());
exprs.push_back(ds::field_ref(field->name()));
exprs.push_back(cp::field_ref(field->name()));
}
// Also derive a new column.
names.push_back("b_large");
exprs.push_back(ds::greater(ds::field_ref("b"), ds::literal(1)));
names.emplace_back("b_large");
exprs.push_back(cp::greater(cp::field_ref("b"), cp::literal(1)));
ABORT_ON_FAILURE(scan_builder->Project(exprs, names));
auto scanner = scan_builder->Finish().ValueOrDie();
return scanner->ToTable().ValueOrDie();
Expand Down Expand Up @@ -295,7 +296,7 @@ std::shared_ptr<arrow::Table> FilterPartitionedDataset(
// Filter based on the partition values. This will mean that we won't even read the
// files whose partition expressions don't match the filter.
ABORT_ON_FAILURE(
scan_builder->Filter(ds::equal(ds::field_ref("part"), ds::literal("b"))));
scan_builder->Filter(cp::equal(cp::field_ref("part"), cp::literal("b"))));
auto scanner = scan_builder->Finish().ValueOrDie();
return scanner->ToTable().ValueOrDie();
}
Expand Down
10 changes: 6 additions & 4 deletions cpp/examples/arrow/dataset_parquet_scan_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
// under the License.

#include <arrow/api.h>
#include <arrow/compute/exec/expression.h>
#include <arrow/dataset/dataset.h>
#include <arrow/dataset/discovery.h>
#include <arrow/dataset/expression.h>
#include <arrow/dataset/file_base.h>
#include <arrow/dataset/file_parquet.h>
#include <arrow/dataset/scanner.h>
Expand All @@ -37,6 +37,8 @@ namespace fs = arrow::fs;

namespace ds = arrow::dataset;

namespace cp = arrow::compute;

#define ABORT_ON_FAILURE(expr) \
do { \
arrow::Status status_ = (expr); \
Expand All @@ -60,8 +62,8 @@ struct Configuration {

// Indicates the filter by which rows will be filtered. This optimization can
// make use of partition information and/or file metadata if possible.
ds::Expression filter =
ds::greater(ds::field_ref("total_amount"), ds::literal(1000.0f));
cp::Expression filter =
cp::greater(cp::field_ref("total_amount"), cp::literal(1000.0f));

ds::InspectOptions inspect_options{};
ds::FinishOptions finish_options{};
Expand Down Expand Up @@ -146,7 +148,7 @@ std::shared_ptr<ds::Dataset> GetDatasetFromPath(

std::shared_ptr<ds::Scanner> GetScannerFromDataset(std::shared_ptr<ds::Dataset> dataset,
std::vector<std::string> columns,
ds::Expression filter,
cp::Expression filter,
bool use_threads) {
auto scanner_builder = dataset->NewScan().ValueOrDie();

Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ if(ARROW_COMPUTE)
compute/api_vector.cc
compute/cast.cc
compute/exec.cc
compute/exec/expression.cc
compute/function.cc
compute/kernel.cc
compute/registry.cc
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/compute/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,5 @@ add_arrow_compute_test(internals_test
add_arrow_benchmark(function_benchmark PREFIX "arrow-compute")

add_subdirectory(kernels)

add_subdirectory(exec)
22 changes: 22 additions & 0 deletions cpp/src/arrow/compute/exec/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

arrow_install_all_headers("arrow/compute/exec")

add_arrow_compute_test(expression_test PREFIX "arrow-compute")

add_arrow_benchmark(expression_benchmark PREFIX "arrow-compute")
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
// specific language governing permissions and limitations
// under the License.

#include "arrow/dataset/expression.h"
#include "arrow/compute/exec/expression.h"

#include <unordered_map>
#include <unordered_set>

#include "arrow/chunked_array.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/exec/expression_internal.h"
#include "arrow/compute/exec_internal.h"
#include "arrow/dataset/expression_internal.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/reader.h"
#include "arrow/ipc/writer.h"
Expand All @@ -39,7 +39,7 @@ namespace arrow {
using internal::checked_cast;
using internal::checked_pointer_cast;

namespace dataset {
namespace compute {

Expression::Expression(Call call) : impl_(std::make_shared<Impl>(std::move(call))) {}

Expand Down Expand Up @@ -198,7 +198,7 @@ std::string Expression::ToString() const {

if (auto options = GetStrptimeOptions(*call)) {
return out + "format=" + options->format +
", unit=" + internal::ToString(options->unit) + ")";
", unit=" + arrow::internal::ToString(options->unit) + ")";
}

return out + "{NON-REPRESENTABLE OPTIONS})";
Expand Down Expand Up @@ -304,8 +304,9 @@ size_t Expression::hash() const {
}

std::shared_ptr<std::atomic<size_t>> expected = nullptr;
internal::atomic_compare_exchange_strong(&const_cast<Call*>(call)->hash, &expected,
std::make_shared<std::atomic<size_t>>(out));
::arrow::internal::atomic_compare_exchange_strong(
&const_cast<Call*>(call)->hash, &expected,
std::make_shared<std::atomic<size_t>>(out));
return out;
}

Expand Down Expand Up @@ -525,6 +526,23 @@ Result<Datum> ExecuteScalarExpression(const Expression& expr, const Datum& input
"ExecuteScalarExpression cannot Execute non-scalar expression ", expr.ToString());
}

if (input.kind() == Datum::TABLE) {
TableBatchReader reader(*input.table());
std::shared_ptr<RecordBatch> batch;

while (true) {
RETURN_NOT_OK(reader.ReadNext(&batch));
if (batch != nullptr) {
break;
}
ARROW_ASSIGN_OR_RAISE(Datum res, ExecuteScalarExpression(expr, batch));
if (res.is_scalar()) {
ARROW_ASSIGN_OR_RAISE(res, MakeArrayFromScalar(*res.scalar(), batch->num_rows(),
exec_context->memory_pool()));
}
}
}

if (auto lit = expr.literal()) return *lit;

if (auto ref = expr.field_ref()) {
Expand Down Expand Up @@ -1156,7 +1174,8 @@ Result<Expression> Deserialize(std::shared_ptr<Buffer> buffer) {

Result<std::shared_ptr<Scalar>> GetScalar(const std::string& i) {
int32_t column_index;
if (!internal::ParseValue<Int32Type>(i.data(), i.length(), &column_index)) {
if (!::arrow::internal::ParseValue<Int32Type>(i.data(), i.length(),
&column_index)) {
return Status::Invalid("Couldn't parse column_index");
}
if (column_index >= batch_.num_columns()) {
Expand Down Expand Up @@ -1279,5 +1298,5 @@ Expression operator||(Expression lhs, Expression rhs) {
return or_(std::move(lhs), std::move(rhs));
}

} // namespace dataset
} // namespace compute
} // namespace arrow
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@
#include "arrow/util/variant.h"

namespace arrow {
namespace dataset {
namespace compute {

/// An unbound expression which maps a single Datum to another Datum.
/// An expression is one of
/// - A literal Datum.
/// - A reference to a single (potentially nested) field of the input Datum.
/// - A call to a compute function, with arguments specified by other Expressions.
class ARROW_DS_EXPORT Expression {
class ARROW_EXPORT Expression {
public:
struct Call {
std::string function_name;
Expand Down Expand Up @@ -122,28 +122,28 @@ class ARROW_DS_EXPORT Expression {
using Impl = util::Variant<Datum, Parameter, Call>;
std::shared_ptr<Impl> impl_;

ARROW_DS_EXPORT friend bool Identical(const Expression& l, const Expression& r);
ARROW_EXPORT friend bool Identical(const Expression& l, const Expression& r);

ARROW_DS_EXPORT friend void PrintTo(const Expression&, std::ostream*);
ARROW_EXPORT friend void PrintTo(const Expression&, std::ostream*);
};

inline bool operator==(const Expression& l, const Expression& r) { return l.Equals(r); }
inline bool operator!=(const Expression& l, const Expression& r) { return !l.Equals(r); }

// Factories

ARROW_DS_EXPORT
ARROW_EXPORT
Expression literal(Datum lit);

template <typename Arg>
Expression literal(Arg&& arg) {
return literal(Datum(std::forward<Arg>(arg)));
}

ARROW_DS_EXPORT
ARROW_EXPORT
Expression field_ref(FieldRef ref);

ARROW_DS_EXPORT
ARROW_EXPORT
Expression call(std::string function, std::vector<Expression> arguments,
std::shared_ptr<compute::FunctionOptions> options = NULLPTR);

Expand All @@ -156,11 +156,11 @@ Expression call(std::string function, std::vector<Expression> arguments,
}

/// Assemble a list of all fields referenced by an Expression at any depth.
ARROW_DS_EXPORT
ARROW_EXPORT
std::vector<FieldRef> FieldsInExpression(const Expression&);

/// Assemble a mapping from field references to known values.
ARROW_DS_EXPORT
ARROW_EXPORT
Result<std::unordered_map<FieldRef, Datum, FieldRef::Hash>> ExtractKnownFieldValues(
const Expression& guaranteed_true_predicate);

Expand All @@ -179,25 +179,25 @@ Result<std::unordered_map<FieldRef, Datum, FieldRef::Hash>> ExtractKnownFieldVal
/// Weak canonicalization which establishes guarantees for subsequent passes. Even
/// equivalent Expressions may result in different canonicalized expressions.
/// TODO this could be a strong canonicalization
ARROW_DS_EXPORT
ARROW_EXPORT
Result<Expression> Canonicalize(Expression, compute::ExecContext* = NULLPTR);

/// Simplify Expressions based on literal arguments (for example, add(null, x) will always
/// be null so replace the call with a null literal). Includes early evaluation of all
/// calls whose arguments are entirely literal.
ARROW_DS_EXPORT
ARROW_EXPORT
Result<Expression> FoldConstants(Expression);

/// Simplify Expressions by replacing with known values of the fields which it references.
ARROW_DS_EXPORT
ARROW_EXPORT
Result<Expression> ReplaceFieldsWithKnownValues(
const std::unordered_map<FieldRef, Datum, FieldRef::Hash>& known_values, Expression);

/// Simplify an expression by replacing subexpressions based on a guarantee:
/// a boolean expression which is guaranteed to evaluate to `true`. For example, this is
/// used to remove redundant function calls from a filter expression or to replace a
/// reference to a constant-value field with a literal.
ARROW_DS_EXPORT
ARROW_EXPORT
Result<Expression> SimplifyWithGuarantee(Expression,
const Expression& guaranteed_true_predicate);

Expand All @@ -207,44 +207,44 @@ Result<Expression> SimplifyWithGuarantee(Expression,

/// Execute a scalar expression against the provided state and input Datum. This
/// expression must be bound.
ARROW_DS_EXPORT
ARROW_EXPORT
Result<Datum> ExecuteScalarExpression(const Expression&, const Datum& input,
compute::ExecContext* = NULLPTR);

// Serialization

ARROW_DS_EXPORT
ARROW_EXPORT
Result<std::shared_ptr<Buffer>> Serialize(const Expression&);

ARROW_DS_EXPORT
ARROW_EXPORT
Result<Expression> Deserialize(std::shared_ptr<Buffer>);

// Convenience aliases for factories

ARROW_DS_EXPORT Expression project(std::vector<Expression> values,
std::vector<std::string> names);
ARROW_EXPORT Expression project(std::vector<Expression> values,
std::vector<std::string> names);

ARROW_DS_EXPORT Expression equal(Expression lhs, Expression rhs);
ARROW_EXPORT Expression equal(Expression lhs, Expression rhs);

ARROW_DS_EXPORT Expression not_equal(Expression lhs, Expression rhs);
ARROW_EXPORT Expression not_equal(Expression lhs, Expression rhs);

ARROW_DS_EXPORT Expression less(Expression lhs, Expression rhs);
ARROW_EXPORT Expression less(Expression lhs, Expression rhs);

ARROW_DS_EXPORT Expression less_equal(Expression lhs, Expression rhs);
ARROW_EXPORT Expression less_equal(Expression lhs, Expression rhs);

ARROW_DS_EXPORT Expression greater(Expression lhs, Expression rhs);
ARROW_EXPORT Expression greater(Expression lhs, Expression rhs);

ARROW_DS_EXPORT Expression greater_equal(Expression lhs, Expression rhs);
ARROW_EXPORT Expression greater_equal(Expression lhs, Expression rhs);

ARROW_DS_EXPORT Expression is_null(Expression lhs);
ARROW_EXPORT Expression is_null(Expression lhs);

ARROW_DS_EXPORT Expression is_valid(Expression lhs);
ARROW_EXPORT Expression is_valid(Expression lhs);

ARROW_DS_EXPORT Expression and_(Expression lhs, Expression rhs);
ARROW_DS_EXPORT Expression and_(const std::vector<Expression>&);
ARROW_DS_EXPORT Expression or_(Expression lhs, Expression rhs);
ARROW_DS_EXPORT Expression or_(const std::vector<Expression>&);
ARROW_DS_EXPORT Expression not_(Expression operand);
ARROW_EXPORT Expression and_(Expression lhs, Expression rhs);
ARROW_EXPORT Expression and_(const std::vector<Expression>&);
ARROW_EXPORT Expression or_(Expression lhs, Expression rhs);
ARROW_EXPORT Expression or_(const std::vector<Expression>&);
ARROW_EXPORT Expression not_(Expression operand);

} // namespace dataset
} // namespace compute
} // namespace arrow
Loading