Skip to content

Commit

Permalink
[REFACTOR][ARITH] Remove legacy compute_expr.h (apache#5738)
Browse files Browse the repository at this point in the history
Replaces most of the ComptuteReduce using foldl.
  • Loading branch information
tqchen authored and Trevor Morris committed Jun 18, 2020
1 parent fc67c8a commit 3c9d909
Show file tree
Hide file tree
Showing 23 changed files with 70 additions and 157 deletions.
25 changes: 25 additions & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ TVM_DLL PrimExpr isinf(PrimExpr x);
* \brief sum of of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
* \return The result.
*/
TVM_DLL PrimExpr sum(PrimExpr source, Array<tir::IterVar> axis);

Expand All @@ -477,27 +478,31 @@ TVM_DLL PrimExpr all(PrimExpr source, Array<tir::IterVar> axis);
* \brief logical Or of of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
* \return The result.
*/
TVM_DLL PrimExpr any(PrimExpr source, Array<tir::IterVar> axis);

/*!
* \brief max of of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
* \return The result.
*/
TVM_DLL PrimExpr max(PrimExpr source, Array<tir::IterVar> axis);

/*!
* \brief max of of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
* \return The result.
*/
TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis);

/*!
* \brief product of of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
* \return The result.
*/
TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis);

Expand Down Expand Up @@ -658,6 +663,17 @@ inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); }
*/
inline bool is_const(const PrimExpr& x);

/*!
* \brief Left fold.
* \param freduce The reduction function.
* \param init_value The initial value.
* \param values The values to be folded.
* \return The result.
* \tparam FReduce The type of the reduction.
*/
template <typename FReduce>
inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr>& values);

/*!
* \brief Check whether x is a constant power of two
* If x is power of two, write the power to the shift.
Expand Down Expand Up @@ -762,6 +778,15 @@ inline PrimExpr make_zero(DataType t) {
}
return make_const(t, 0);
}

template <typename FReduce>
inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr>& values) {
for (PrimExpr val : values) {
init_value = freduce(init_value, val);
}
return init_value;
}

} // namespace tir

// additional const expression overloading
Expand Down
109 changes: 0 additions & 109 deletions src/arith/compute_expr.h

This file was deleted.

1 change: 0 additions & 1 deletion src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@

#include <vector>

#include "../../../arith/compute_expr.h"
#include "../../transforms/infer_layout_util.h"
#include "../../transforms/pattern_util.h"
#include "../op_common.h"
Expand Down
1 change: 0 additions & 1 deletion src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
#include <utility>
#include <vector>

#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
#include "../../tir/transforms/ir_util.h"
#include "llvm_common.h"
Expand Down
1 change: 0 additions & 1 deletion src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include <cctype>
#include <iomanip>

#include "../../arith/compute_expr.h"
#include "../../arith/pattern_match.h"

namespace tvm {
Expand Down
3 changes: 1 addition & 2 deletions src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@

#include <tvm/runtime/container.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include <string>

#include "../../arith/compute_expr.h"

namespace tvm {
namespace codegen {

Expand Down
5 changes: 2 additions & 3 deletions src/te/operation/compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
#include <unordered_set>
#include <utility>

#include "../../arith/compute_expr.h"
#include "../../arith/interval_set.h"
#include "../schedule/message_passing.h"
#include "op_util.h"
Expand Down Expand Up @@ -593,8 +592,8 @@ Stmt TransformUpdate(const Stage& stage, const std::unordered_map<IterVar, Range
}
}

return IfThenElseNode::make(arith::ComputeReduce<tir::OrNode>(conds, const_true(1)), update,
body);
auto cond = foldl([](PrimExpr a, PrimExpr b) { return a || b; }, const_false(1), conds);
return IfThenElseNode::make(cond, update, body);
}

} // namespace te
Expand Down
1 change: 0 additions & 1 deletion src/te/operation/op_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

#include <string>

#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
#include "../schedule/message_passing.h"

Expand Down
1 change: 0 additions & 1 deletion src/te/operation/tensor_compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

#include <unordered_set>

#include "../../arith/compute_expr.h"
#include "./compute_op.h"
#include "./op_util.h"

Expand Down
2 changes: 0 additions & 2 deletions src/te/schedule/message_passing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>

#include "../../arith/compute_expr.h"

namespace tvm {
namespace te {

Expand Down
13 changes: 8 additions & 5 deletions src/te/schedule/schedule_dataflow_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
*/
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include <unordered_set>

#include "../../arith/compute_expr.h"
#include "../../tir/transforms/ir_util.h"
#include "message_passing.h"
#include "operation_inline.h"
Expand Down Expand Up @@ -89,13 +89,14 @@ PrimExpr InjectPredicate(const Array<PrimExpr>& predicates, PrimExpr body) {
using tir::SelectNode;
if (predicates.size() == 0) return body;
const ReduceNode* reduce = body.as<ReduceNode>();
auto fand = [](PrimExpr a, PrimExpr b) { return a && b; };

if (reduce) {
auto n = make_object<ReduceNode>(*reduce);
n->condition = n->condition && arith::ComputeReduce<tir::AndNode>(predicates, PrimExpr());
n->condition = foldl(fand, n->condition, predicates);
return PrimExpr(n);
}
return SelectNode::make(arith::ComputeReduce<tir::AndNode>(predicates, PrimExpr()), body,
make_zero(body.dtype()));
return SelectNode::make(foldl(fand, const_true(1), predicates), body, make_zero(body.dtype()));
}

// Replace data flow appears in all stages given the tensor change.
Expand Down Expand Up @@ -707,7 +708,9 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int f
const ReduceNode* reduce = compute_op->body[idx].as<ReduceNode>();
CHECK(reduce) << "Can only rfactor non-inline reductions";
predicates.push_back(reduce->condition);
PrimExpr predicate = likely(arith::ComputeReduce<tir::AndNode>(predicates, PrimExpr()));
auto fand = [](PrimExpr a, PrimExpr b) { return a && b; };

PrimExpr predicate = likely(foldl(fand, const_true(1), predicates));

std::unordered_map<const VarNode*, PrimExpr> vsub;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@

#include <unordered_map>

#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"

namespace tvm {
Expand Down
6 changes: 3 additions & 3 deletions src/tir/ir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@
#include <tvm/tir/analysis.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include <iterator>
#include <stack>

#include "../../arith/compute_expr.h"

namespace tvm {
namespace tir {

Expand Down Expand Up @@ -367,7 +366,8 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane
int highest_dim = 0;
extent = self->strides[highest_dim] * self->shape[highest_dim] - offset;
} else {
extent = arith::ComputeReduce<tir::MulNode>(self->shape, PrimExpr()) - offset;
auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; };
extent = foldl(fmul, make_const(DataType::Int(32), 1), self->shape) - offset;
}
PrimExpr elem_offset = self->elem_offset + offset;
if (content_lanes > 1) {
Expand Down
7 changes: 4 additions & 3 deletions src/tir/transforms/arg_binder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@

#include <tvm/runtime/device_api.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include "../../arith/compute_expr.h"
#include "ir_util.h"

namespace tvm {
Expand Down Expand Up @@ -225,8 +225,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
<< " expected to be compact array";
if (conds.size() != 0) {
auto stride_msg = tvm::tir::StringImmNode::make(stride_err_msg.str());
Stmt check = AssertStmtNode::make(arith::ComputeReduce<tir::AndNode>(conds, PrimExpr()),
stride_msg, EvaluateNode::make(0));
auto fand = [](PrimExpr a, PrimExpr b) { return a && b; };
Stmt check = AssertStmtNode::make(foldl(fand, const_true(1), conds), stride_msg,
EvaluateNode::make(0));
check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt());
asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)}));
}
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/inject_double_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../../arith/compute_expr.h"
#include "ir_util.h"

namespace tvm {
Expand Down Expand Up @@ -115,8 +114,9 @@ class DoubleBufferInjector : public StmtExprMutator {
Stmt VisitStmt_(const AllocateNode* op) final {
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; };
it->second.stride =
arith::ComputeReduce<MulNode>(op->extents, PrimExpr()) * op->dtype.lanes();
foldl(fmul, make_const(DataType::Int(32), 1), op->extents) * op->dtype.lanes();
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
Array<PrimExpr> new_extents{make_const(op->extents[0].dtype(), 2)};
Expand Down
5 changes: 3 additions & 2 deletions src/tir/transforms/inject_virtual_thread.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

#include <unordered_set>

#include "../../arith/compute_expr.h"
#include "ir_util.h"

namespace tvm {
Expand Down Expand Up @@ -368,7 +367,9 @@ class VTInjector : public StmtExprMutator {
// always rewrite if not allow sharing.
if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
// place v on highest dimension.
PrimExpr stride = arith::ComputeReduce<MulNode>(op->extents, PrimExpr()) * op->dtype.lanes();
auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; };
PrimExpr stride =
foldl(fmul, make_const(DataType::Int(32), 1), op->extents) * op->dtype.lanes();
Array<PrimExpr> other;
other.push_back(make_const(op->extents[0].dtype(), num_threads_));
for (PrimExpr e : extents) {
Expand Down
Loading

0 comments on commit 3c9d909

Please sign in to comment.