Skip to content

Commit

Permalink
ARROW-11929: [C++][Dataset][Compute] Promote expression to the comput…
Browse files Browse the repository at this point in the history
…e namespace

Moves Expression and its test and benchmark into the compute/exec/ directory. I haven't introduced an exec namespace.

Closes apache#10166 from bkietz/11929-Promote-Expression-to-the

Authored-by: Benjamin Kietzman <[email protected]>
Signed-off-by: David Li <[email protected]>
  • Loading branch information
bkietz authored and lidavidm committed Apr 30, 2021
1 parent c501761 commit 7430bbd
Show file tree
Hide file tree
Showing 45 changed files with 473 additions and 368 deletions.
21 changes: 11 additions & 10 deletions cpp/examples/arrow/dataset_documentation_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@

#include <arrow/api.h>
#include <arrow/compute/cast.h>
#include <arrow/compute/exec/expression.h>
#include <arrow/dataset/dataset.h>
#include <arrow/dataset/discovery.h>
#include <arrow/dataset/expression.h>
#include <arrow/dataset/file_base.h>
#include <arrow/dataset/file_ipc.h>
#include <arrow/dataset/file_parquet.h>
Expand All @@ -37,6 +37,7 @@

namespace ds = arrow::dataset;
namespace fs = arrow::fs;
namespace cp = arrow::compute;

#define ABORT_ON_FAILURE(expr) \
do { \
Expand Down Expand Up @@ -185,7 +186,7 @@ std::shared_ptr<arrow::Table> FilterAndSelectDataset(
// Read specified columns with a row filter
auto scan_builder = dataset->NewScan().ValueOrDie();
ABORT_ON_FAILURE(scan_builder->Project({"b"}));
ABORT_ON_FAILURE(scan_builder->Filter(ds::less(ds::field_ref("b"), ds::literal(4))));
ABORT_ON_FAILURE(scan_builder->Filter(cp::less(cp::field_ref("b"), cp::literal(4))));
auto scanner = scan_builder->Finish().ValueOrDie();
return scanner->ToTable().ValueOrDie();
}
Expand All @@ -210,12 +211,12 @@ std::shared_ptr<arrow::Table> ProjectDataset(
ABORT_ON_FAILURE(scan_builder->Project(
{
// Leave column "a" as-is.
ds::field_ref("a"),
cp::field_ref("a"),
// Cast column "b" to float32.
ds::call("cast", {ds::field_ref("b")},
cp::call("cast", {cp::field_ref("b")},
arrow::compute::CastOptions::Safe(arrow::float32())),
// Derive a boolean column from "c".
ds::equal(ds::field_ref("c"), ds::literal(1)),
cp::equal(cp::field_ref("c"), cp::literal(1)),
},
{"a_renamed", "b_as_float32", "c_1"}));
auto scanner = scan_builder->Finish().ValueOrDie();
Expand All @@ -239,15 +240,15 @@ std::shared_ptr<arrow::Table> SelectAndProjectDataset(
// Read specified columns with a row filter
auto scan_builder = dataset->NewScan().ValueOrDie();
std::vector<std::string> names;
std::vector<ds::Expression> exprs;
std::vector<cp::Expression> exprs;
// Read all the original columns.
for (const auto& field : dataset->schema()->fields()) {
names.push_back(field->name());
exprs.push_back(ds::field_ref(field->name()));
exprs.push_back(cp::field_ref(field->name()));
}
// Also derive a new column.
names.push_back("b_large");
exprs.push_back(ds::greater(ds::field_ref("b"), ds::literal(1)));
names.emplace_back("b_large");
exprs.push_back(cp::greater(cp::field_ref("b"), cp::literal(1)));
ABORT_ON_FAILURE(scan_builder->Project(exprs, names));
auto scanner = scan_builder->Finish().ValueOrDie();
return scanner->ToTable().ValueOrDie();
Expand Down Expand Up @@ -295,7 +296,7 @@ std::shared_ptr<arrow::Table> FilterPartitionedDataset(
// Filter based on the partition values. This will mean that we won't even read the
// files whose partition expressions don't match the filter.
ABORT_ON_FAILURE(
scan_builder->Filter(ds::equal(ds::field_ref("part"), ds::literal("b"))));
scan_builder->Filter(cp::equal(cp::field_ref("part"), cp::literal("b"))));
auto scanner = scan_builder->Finish().ValueOrDie();
return scanner->ToTable().ValueOrDie();
}
Expand Down
10 changes: 6 additions & 4 deletions cpp/examples/arrow/dataset_parquet_scan_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
// under the License.

#include <arrow/api.h>
#include <arrow/compute/exec/expression.h>
#include <arrow/dataset/dataset.h>
#include <arrow/dataset/discovery.h>
#include <arrow/dataset/expression.h>
#include <arrow/dataset/file_base.h>
#include <arrow/dataset/file_parquet.h>
#include <arrow/dataset/scanner.h>
Expand All @@ -37,6 +37,8 @@ namespace fs = arrow::fs;

namespace ds = arrow::dataset;

namespace cp = arrow::compute;

#define ABORT_ON_FAILURE(expr) \
do { \
arrow::Status status_ = (expr); \
Expand All @@ -60,8 +62,8 @@ struct Configuration {

// Indicates the filter by which rows will be filtered. This optimization can
// make use of partition information and/or file metadata if possible.
ds::Expression filter =
ds::greater(ds::field_ref("total_amount"), ds::literal(1000.0f));
cp::Expression filter =
cp::greater(cp::field_ref("total_amount"), cp::literal(1000.0f));

ds::InspectOptions inspect_options{};
ds::FinishOptions finish_options{};
Expand Down Expand Up @@ -146,7 +148,7 @@ std::shared_ptr<ds::Dataset> GetDatasetFromPath(

std::shared_ptr<ds::Scanner> GetScannerFromDataset(std::shared_ptr<ds::Dataset> dataset,
std::vector<std::string> columns,
ds::Expression filter,
cp::Expression filter,
bool use_threads) {
auto scanner_builder = dataset->NewScan().ValueOrDie();

Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ if(ARROW_COMPUTE)
compute/api_vector.cc
compute/cast.cc
compute/exec.cc
compute/exec/expression.cc
compute/function.cc
compute/kernel.cc
compute/registry.cc
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/compute/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,5 @@ add_arrow_compute_test(internals_test
add_arrow_benchmark(function_benchmark PREFIX "arrow-compute")

add_subdirectory(kernels)

add_subdirectory(exec)
22 changes: 22 additions & 0 deletions cpp/src/arrow/compute/exec/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

arrow_install_all_headers("arrow/compute/exec")

add_arrow_compute_test(expression_test PREFIX "arrow-compute")

add_arrow_benchmark(expression_benchmark PREFIX "arrow-compute")
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
// specific language governing permissions and limitations
// under the License.

#include "arrow/dataset/expression.h"
#include "arrow/compute/exec/expression.h"

#include <unordered_map>
#include <unordered_set>

#include "arrow/chunked_array.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/exec/expression_internal.h"
#include "arrow/compute/exec_internal.h"
#include "arrow/dataset/expression_internal.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/reader.h"
#include "arrow/ipc/writer.h"
Expand All @@ -39,7 +39,7 @@ namespace arrow {
using internal::checked_cast;
using internal::checked_pointer_cast;

namespace dataset {
namespace compute {

Expression::Expression(Call call) : impl_(std::make_shared<Impl>(std::move(call))) {}

Expand Down Expand Up @@ -198,7 +198,7 @@ std::string Expression::ToString() const {

if (auto options = GetStrptimeOptions(*call)) {
return out + "format=" + options->format +
", unit=" + internal::ToString(options->unit) + ")";
", unit=" + arrow::internal::ToString(options->unit) + ")";
}

return out + "{NON-REPRESENTABLE OPTIONS})";
Expand Down Expand Up @@ -304,8 +304,9 @@ size_t Expression::hash() const {
}

std::shared_ptr<std::atomic<size_t>> expected = nullptr;
internal::atomic_compare_exchange_strong(&const_cast<Call*>(call)->hash, &expected,
std::make_shared<std::atomic<size_t>>(out));
::arrow::internal::atomic_compare_exchange_strong(
&const_cast<Call*>(call)->hash, &expected,
std::make_shared<std::atomic<size_t>>(out));
return out;
}

Expand Down Expand Up @@ -525,6 +526,23 @@ Result<Datum> ExecuteScalarExpression(const Expression& expr, const Datum& input
"ExecuteScalarExpression cannot Execute non-scalar expression ", expr.ToString());
}

if (input.kind() == Datum::TABLE) {
TableBatchReader reader(*input.table());
std::shared_ptr<RecordBatch> batch;

while (true) {
RETURN_NOT_OK(reader.ReadNext(&batch));
if (batch != nullptr) {
break;
}
ARROW_ASSIGN_OR_RAISE(Datum res, ExecuteScalarExpression(expr, batch));
if (res.is_scalar()) {
ARROW_ASSIGN_OR_RAISE(res, MakeArrayFromScalar(*res.scalar(), batch->num_rows(),
exec_context->memory_pool()));
}
}
}

if (auto lit = expr.literal()) return *lit;

if (auto ref = expr.field_ref()) {
Expand Down Expand Up @@ -1156,7 +1174,8 @@ Result<Expression> Deserialize(std::shared_ptr<Buffer> buffer) {

Result<std::shared_ptr<Scalar>> GetScalar(const std::string& i) {
int32_t column_index;
if (!internal::ParseValue<Int32Type>(i.data(), i.length(), &column_index)) {
if (!::arrow::internal::ParseValue<Int32Type>(i.data(), i.length(),
&column_index)) {
return Status::Invalid("Couldn't parse column_index");
}
if (column_index >= batch_.num_columns()) {
Expand Down Expand Up @@ -1279,5 +1298,5 @@ Expression operator||(Expression lhs, Expression rhs) {
return or_(std::move(lhs), std::move(rhs));
}

} // namespace dataset
} // namespace compute
} // namespace arrow
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@
#include "arrow/util/variant.h"

namespace arrow {
namespace dataset {
namespace compute {

/// An unbound expression which maps a single Datum to another Datum.
/// An expression is one of
/// - A literal Datum.
/// - A reference to a single (potentially nested) field of the input Datum.
/// - A call to a compute function, with arguments specified by other Expressions.
class ARROW_DS_EXPORT Expression {
class ARROW_EXPORT Expression {
public:
struct Call {
std::string function_name;
Expand Down Expand Up @@ -122,28 +122,28 @@ class ARROW_DS_EXPORT Expression {
using Impl = util::Variant<Datum, Parameter, Call>;
std::shared_ptr<Impl> impl_;

ARROW_DS_EXPORT friend bool Identical(const Expression& l, const Expression& r);
ARROW_EXPORT friend bool Identical(const Expression& l, const Expression& r);

ARROW_DS_EXPORT friend void PrintTo(const Expression&, std::ostream*);
ARROW_EXPORT friend void PrintTo(const Expression&, std::ostream*);
};

inline bool operator==(const Expression& l, const Expression& r) { return l.Equals(r); }
inline bool operator!=(const Expression& l, const Expression& r) { return !l.Equals(r); }

// Factories

ARROW_DS_EXPORT
ARROW_EXPORT
Expression literal(Datum lit);

template <typename Arg>
Expression literal(Arg&& arg) {
return literal(Datum(std::forward<Arg>(arg)));
}

ARROW_DS_EXPORT
ARROW_EXPORT
Expression field_ref(FieldRef ref);

ARROW_DS_EXPORT
ARROW_EXPORT
Expression call(std::string function, std::vector<Expression> arguments,
std::shared_ptr<compute::FunctionOptions> options = NULLPTR);

Expand All @@ -156,11 +156,11 @@ Expression call(std::string function, std::vector<Expression> arguments,
}

/// Assemble a list of all fields referenced by an Expression at any depth.
ARROW_DS_EXPORT
ARROW_EXPORT
std::vector<FieldRef> FieldsInExpression(const Expression&);

/// Assemble a mapping from field references to known values.
ARROW_DS_EXPORT
ARROW_EXPORT
Result<std::unordered_map<FieldRef, Datum, FieldRef::Hash>> ExtractKnownFieldValues(
const Expression& guaranteed_true_predicate);

Expand All @@ -179,25 +179,25 @@ Result<std::unordered_map<FieldRef, Datum, FieldRef::Hash>> ExtractKnownFieldVal
/// Weak canonicalization which establishes guarantees for subsequent passes. Even
/// equivalent Expressions may result in different canonicalized expressions.
/// TODO this could be a strong canonicalization
ARROW_DS_EXPORT
ARROW_EXPORT
Result<Expression> Canonicalize(Expression, compute::ExecContext* = NULLPTR);

/// Simplify Expressions based on literal arguments (for example, add(null, x) will always
/// be null so replace the call with a null literal). Includes early evaluation of all
/// calls whose arguments are entirely literal.
ARROW_DS_EXPORT
ARROW_EXPORT
Result<Expression> FoldConstants(Expression);

/// Simplify Expressions by replacing with known values of the fields which it references.
ARROW_DS_EXPORT
ARROW_EXPORT
Result<Expression> ReplaceFieldsWithKnownValues(
const std::unordered_map<FieldRef, Datum, FieldRef::Hash>& known_values, Expression);

/// Simplify an expression by replacing subexpressions based on a guarantee:
/// a boolean expression which is guaranteed to evaluate to `true`. For example, this is
/// used to remove redundant function calls from a filter expression or to replace a
/// reference to a constant-value field with a literal.
ARROW_DS_EXPORT
ARROW_EXPORT
Result<Expression> SimplifyWithGuarantee(Expression,
const Expression& guaranteed_true_predicate);

Expand All @@ -207,44 +207,44 @@ Result<Expression> SimplifyWithGuarantee(Expression,

/// Execute a scalar expression against the provided state and input Datum. This
/// expression must be bound.
ARROW_DS_EXPORT
ARROW_EXPORT
Result<Datum> ExecuteScalarExpression(const Expression&, const Datum& input,
compute::ExecContext* = NULLPTR);

// Serialization

ARROW_DS_EXPORT
ARROW_EXPORT
Result<std::shared_ptr<Buffer>> Serialize(const Expression&);

ARROW_DS_EXPORT
ARROW_EXPORT
Result<Expression> Deserialize(std::shared_ptr<Buffer>);

// Convenience aliases for factories

ARROW_DS_EXPORT Expression project(std::vector<Expression> values,
std::vector<std::string> names);
ARROW_EXPORT Expression project(std::vector<Expression> values,
std::vector<std::string> names);

ARROW_DS_EXPORT Expression equal(Expression lhs, Expression rhs);
ARROW_EXPORT Expression equal(Expression lhs, Expression rhs);

ARROW_DS_EXPORT Expression not_equal(Expression lhs, Expression rhs);
ARROW_EXPORT Expression not_equal(Expression lhs, Expression rhs);

ARROW_DS_EXPORT Expression less(Expression lhs, Expression rhs);
ARROW_EXPORT Expression less(Expression lhs, Expression rhs);

ARROW_DS_EXPORT Expression less_equal(Expression lhs, Expression rhs);
ARROW_EXPORT Expression less_equal(Expression lhs, Expression rhs);

ARROW_DS_EXPORT Expression greater(Expression lhs, Expression rhs);
ARROW_EXPORT Expression greater(Expression lhs, Expression rhs);

ARROW_DS_EXPORT Expression greater_equal(Expression lhs, Expression rhs);
ARROW_EXPORT Expression greater_equal(Expression lhs, Expression rhs);

ARROW_DS_EXPORT Expression is_null(Expression lhs);
ARROW_EXPORT Expression is_null(Expression lhs);

ARROW_DS_EXPORT Expression is_valid(Expression lhs);
ARROW_EXPORT Expression is_valid(Expression lhs);

ARROW_DS_EXPORT Expression and_(Expression lhs, Expression rhs);
ARROW_DS_EXPORT Expression and_(const std::vector<Expression>&);
ARROW_DS_EXPORT Expression or_(Expression lhs, Expression rhs);
ARROW_DS_EXPORT Expression or_(const std::vector<Expression>&);
ARROW_DS_EXPORT Expression not_(Expression operand);
ARROW_EXPORT Expression and_(Expression lhs, Expression rhs);
ARROW_EXPORT Expression and_(const std::vector<Expression>&);
ARROW_EXPORT Expression or_(Expression lhs, Expression rhs);
ARROW_EXPORT Expression or_(const std::vector<Expression>&);
ARROW_EXPORT Expression not_(Expression operand);

} // namespace dataset
} // namespace compute
} // namespace arrow
Loading

0 comments on commit 7430bbd

Please sign in to comment.