diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index a4748d525829..8d2add2b896a 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -463,6 +463,7 @@ TVM_DLL PrimExpr isinf(PrimExpr x); * \brief sum of of source expression over axis * \param source The source expression. * \param axis List of iteration variables that will be used for reduction. + * \return The result. */ TVM_DLL PrimExpr sum(PrimExpr source, Array axis); @@ -477,6 +478,7 @@ TVM_DLL PrimExpr all(PrimExpr source, Array axis); * \brief logical Or of of source expression over axis * \param source The source expression. * \param axis List of iteration variables that will be used for reduction. + * \return The result. */ TVM_DLL PrimExpr any(PrimExpr source, Array axis); @@ -484,6 +486,7 @@ TVM_DLL PrimExpr any(PrimExpr source, Array axis); * \brief max of of source expression over axis * \param source The source expression. * \param axis List of iteration variables that will be used for reduction. + * \return The result. */ TVM_DLL PrimExpr max(PrimExpr source, Array axis); @@ -491,6 +494,7 @@ TVM_DLL PrimExpr max(PrimExpr source, Array axis); * \brief max of of source expression over axis * \param source The source expression. * \param axis List of iteration variables that will be used for reduction. + * \return The result. */ TVM_DLL PrimExpr min(PrimExpr source, Array axis); @@ -498,6 +502,7 @@ TVM_DLL PrimExpr min(PrimExpr source, Array axis); * \brief product of of source expression over axis * \param source The source expression. * \param axis List of iteration variables that will be used for reduction. + * \return The result. */ TVM_DLL PrimExpr prod(PrimExpr source, Array axis); @@ -658,6 +663,17 @@ inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); } */ inline bool is_const(const PrimExpr& x); +/*! + * \brief Left fold. + * \param freduce The reduction function. + * \param init_value The initial value. + * \param values The values to be folded. + * \return The result. + * \tparam FReduce The type of the reduction. + */ +template +inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array& values); + /*! * \brief Check whether x is a constant power of two * If x is power of two, write the power to the shift. @@ -762,6 +778,15 @@ inline PrimExpr make_zero(DataType t) { } return make_const(t, 0); } + +template +inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array& values) { + for (PrimExpr val : values) { + init_value = freduce(init_value, val); + } + return init_value; +} + } // namespace tir // additional const expression overloading diff --git a/src/arith/compute_expr.h b/src/arith/compute_expr.h deleted file mode 100644 index 39530ff9d49f..000000000000 --- a/src/arith/compute_expr.h +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file compute_expr.h - * \brief Utility to invoke certan compute operations. - */ -#ifndef TVM_ARITH_COMPUTE_EXPR_H_ -#define TVM_ARITH_COMPUTE_EXPR_H_ - -#include -#include - -#include -#include - -namespace tvm { -namespace arith { - -/*! - * \brief Compute the expression with the given binary op. - * \param lhs The left operand - * \param rhs The right operand - * \tparam Op the computation operator - * \return The result. - */ -template -inline PrimExpr Compute(PrimExpr lhs, PrimExpr rhs) { - return OP::make(lhs, rhs); -} - -/*! - * \brief Compute an reduction with Op - * \param values The input values. - * \param empty_value The value when return if it is empty, can be Expr() - * which will cause an error to be rasied. - * \tparam Op The computation operator - * \return The result. - */ -template -inline PrimExpr ComputeReduce(const Array& values, PrimExpr empty_value); - -template <> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { - return a + b; -} - -template <> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { - return a - b; -} - -template <> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { - return a * b; -} - -template <> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { - return truncdiv(a, b); -} - -template <> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { - return truncmod(a, b); -} - -template <> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { - return max(a, b); -} - -template <> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { - return min(a, b); -} - -template -inline PrimExpr ComputeReduce(const Array& values, PrimExpr empty_value) { - if (values.size() == 0U) { - CHECK(empty_value.defined()); - return empty_value; - } - PrimExpr res = values[0]; - for (size_t i = 1; i < values.size(); ++i) { - res = Compute(res, values[i]); - } - return res; -} - -} // namespace arith -} // namespace tvm -#endif // TVM_ARITH_COMPUTE_EXPR_H_ diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index a80bb31774e3..0275a893e1f4 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -38,7 +38,6 @@ #include -#include "../../../arith/compute_expr.h" #include "../../transforms/infer_layout_util.h" #include "../../transforms/pattern_util.h" #include "../op_common.h" diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 4522c150b39b..0bca2a169ba4 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -42,7 +42,6 @@ #include #include -#include "../../arith/compute_expr.h" #include "../../runtime/thread_storage_scope.h" #include "../../tir/transforms/ir_util.h" #include "llvm_common.h" diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 69dab6238225..9255d7c80c46 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -25,7 +25,6 @@ #include #include -#include "../../arith/compute_expr.h" #include "../../arith/pattern_match.h" namespace tvm { diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index e76e8bee81fc..364a62fa0e3e 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -25,11 +25,10 @@ #include #include +#include #include -#include "../../arith/compute_expr.h" - namespace tvm { namespace codegen { diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 048285d83360..cc843a48650e 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -34,7 +34,6 @@ #include #include -#include "../../arith/compute_expr.h" #include "../../arith/interval_set.h" #include "../schedule/message_passing.h" #include "op_util.h" @@ -593,8 +592,8 @@ Stmt TransformUpdate(const Stage& stage, const std::unordered_map(conds, const_true(1)), update, - body); + auto cond = foldl([](PrimExpr a, PrimExpr b) { return a || b; }, const_false(1), conds); + return IfThenElseNode::make(cond, update, body); } } // namespace te diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc index 5b200ac0ce94..341e7610463b 100644 --- a/src/te/operation/op_util.cc +++ b/src/te/operation/op_util.cc @@ -29,7 +29,6 @@ #include -#include "../../arith/compute_expr.h" #include "../../runtime/thread_storage_scope.h" #include "../schedule/message_passing.h" diff --git a/src/te/operation/tensor_compute_op.cc b/src/te/operation/tensor_compute_op.cc index 236aff68b44c..96ddb36afb7a 100644 --- a/src/te/operation/tensor_compute_op.cc +++ b/src/te/operation/tensor_compute_op.cc @@ -29,7 +29,6 @@ #include -#include "../../arith/compute_expr.h" #include "./compute_op.h" #include "./op_util.h" diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 4f0e98243a02..55593be34212 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -26,8 +26,6 @@ #include #include -#include "../../arith/compute_expr.h" - namespace tvm { namespace te { diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index 95612635a3d9..cfd8b26d95be 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -22,11 +22,11 @@ */ #include #include +#include #include #include -#include "../../arith/compute_expr.h" #include "../../tir/transforms/ir_util.h" #include "message_passing.h" #include "operation_inline.h" @@ -89,13 +89,14 @@ PrimExpr InjectPredicate(const Array& predicates, PrimExpr body) { using tir::SelectNode; if (predicates.size() == 0) return body; const ReduceNode* reduce = body.as(); + auto fand = [](PrimExpr a, PrimExpr b) { return a && b; }; + if (reduce) { auto n = make_object(*reduce); - n->condition = n->condition && arith::ComputeReduce(predicates, PrimExpr()); + n->condition = foldl(fand, n->condition, predicates); return PrimExpr(n); } - return SelectNode::make(arith::ComputeReduce(predicates, PrimExpr()), body, - make_zero(body.dtype())); + return SelectNode::make(foldl(fand, const_true(1), predicates), body, make_zero(body.dtype())); } // Replace data flow appears in all stages given the tensor change. @@ -707,7 +708,9 @@ Array Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int f const ReduceNode* reduce = compute_op->body[idx].as(); CHECK(reduce) << "Can only rfactor non-inline reductions"; predicates.push_back(reduce->condition); - PrimExpr predicate = likely(arith::ComputeReduce(predicates, PrimExpr())); + auto fand = [](PrimExpr a, PrimExpr b) { return a && b; }; + + PrimExpr predicate = likely(foldl(fand, const_true(1), predicates)); std::unordered_map vsub; diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc index 84166d11881b..46fc91befd47 100644 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc @@ -36,7 +36,6 @@ #include -#include "../../arith/compute_expr.h" #include "../../runtime/thread_storage_scope.h" namespace tvm { diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 8b98ed9d14d9..3a60521dfb9b 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -26,12 +26,11 @@ #include #include #include +#include #include #include -#include "../../arith/compute_expr.h" - namespace tvm { namespace tir { @@ -367,7 +366,8 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane int highest_dim = 0; extent = self->strides[highest_dim] * self->shape[highest_dim] - offset; } else { - extent = arith::ComputeReduce(self->shape, PrimExpr()) - offset; + auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; }; + extent = foldl(fmul, make_const(DataType::Int(32), 1), self->shape) - offset; } PrimExpr elem_offset = self->elem_offset + offset; if (content_lanes > 1) { diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index 01a69969b489..384d4593f06d 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -25,8 +25,8 @@ #include #include +#include -#include "../../arith/compute_expr.h" #include "ir_util.h" namespace tvm { @@ -225,8 +225,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, << " expected to be compact array"; if (conds.size() != 0) { auto stride_msg = tvm::tir::StringImmNode::make(stride_err_msg.str()); - Stmt check = AssertStmtNode::make(arith::ComputeReduce(conds, PrimExpr()), - stride_msg, EvaluateNode::make(0)); + auto fand = [](PrimExpr a, PrimExpr b) { return a && b; }; + Stmt check = AssertStmtNode::make(foldl(fand, const_true(1), conds), stride_msg, + EvaluateNode::make(0)); check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt()); asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)})); } diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index ae5e673d8b7e..c405b1f5a679 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -26,7 +26,6 @@ #include #include -#include "../../arith/compute_expr.h" #include "ir_util.h" namespace tvm { @@ -115,8 +114,9 @@ class DoubleBufferInjector : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { auto it = dbuffer_info_.find(op->buffer_var.get()); if (it != dbuffer_info_.end()) { + auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; }; it->second.stride = - arith::ComputeReduce(op->extents, PrimExpr()) * op->dtype.lanes(); + foldl(fmul, make_const(DataType::Int(32), 1), op->extents) * op->dtype.lanes(); Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); Array new_extents{make_const(op->extents[0].dtype(), 2)}; diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index 834a7e908f76..e2a027d2f1f4 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -27,7 +27,6 @@ #include -#include "../../arith/compute_expr.h" #include "ir_util.h" namespace tvm { @@ -368,7 +367,9 @@ class VTInjector : public StmtExprMutator { // always rewrite if not allow sharing. if (touched_var_.count(op->buffer_var.get()) || !allow_share_) { // place v on highest dimension. - PrimExpr stride = arith::ComputeReduce(op->extents, PrimExpr()) * op->dtype.lanes(); + auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; }; + PrimExpr stride = + foldl(fmul, make_const(DataType::Int(32), 1), op->extents) * op->dtype.lanes(); Array other; other.push_back(make_const(op->extents[0].dtype(), num_threads_)); for (PrimExpr e : extents) { diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index de86647535f6..7f9a3291c80b 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -30,7 +30,6 @@ #include -#include "../../arith/compute_expr.h" #include "../../runtime/thread_storage_scope.h" #include "ir_util.h" diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 91879b6a4b82..fb86bc20da48 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -31,12 +31,12 @@ #include #include #include +#include #include #include #include -#include "../../arith/compute_expr.h" #include "../../arith/pattern_match.h" #include "../../runtime/thread_storage_scope.h" diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index cd749b9ced81..3a4213785a16 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -23,11 +23,11 @@ #include "storage_access.h" #include +#include #include #include -#include "../../arith/compute_expr.h" #include "ir_util.h" namespace tvm { diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 646e00855c2b..1e656ce05296 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -37,7 +37,6 @@ #include -#include "../../arith/compute_expr.h" #include "../../arith/ir_visitor_with_analyzer.h" #include "../../runtime/thread_storage_scope.h" #include "arg_binder.h" diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index fc86f2bdf348..365ff75d03f9 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -34,7 +34,6 @@ #include #include -#include "../../arith/compute_expr.h" #include "../../runtime/thread_storage_scope.h" #include "ir_util.h" @@ -555,10 +554,12 @@ class StoragePlanRewriter : public StmtExprMutator { alloc_type = op->dtype; } } + + auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; }; + if (e->allocs.size() == 1) { // simply use the original allocation. - PrimExpr sz = arith::ComputeReduce(e->allocs[0]->extents, - make_const(DataType::Int(32), 1)); + PrimExpr sz = foldl(fmul, make_const(DataType::Int(32), 1), e->allocs[0]->extents); e->new_alloc = AllocateNode::make(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, EvaluateNode::make(0)); if (e->scope.tag.length() != 0) { @@ -571,8 +572,7 @@ class StoragePlanRewriter : public StmtExprMutator { // Build a merged allocation PrimExpr combo_size; for (const AllocateNode* op : e->allocs) { - PrimExpr sz = - arith::ComputeReduce(op->extents, make_const(DataType::Int(32), 1)); + PrimExpr sz = foldl(fmul, make_const(DataType::Int(32), 1), op->extents); auto nbits = op->dtype.bits() * op->dtype.lanes(); if (const auto* imm = sz.as()) { if (imm->value > std::numeric_limits::max() / nbits) { diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index 4ccfbc3840b5..fd1a92a70b69 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -33,7 +33,6 @@ #include #include -#include "../../arith/compute_expr.h" #include "ir_util.h" namespace tvm { diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 9e553cb12ceb..91993ac3b1ee 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -31,8 +32,6 @@ #include #include -#include "../../arith/compute_expr.h" - namespace tvm { namespace tir { @@ -109,8 +108,14 @@ class Vectorizer : public StmtExprMutator { } } - PrimExpr VisitExpr_(const AddNode* op) final { return AddSubVec(op); } - PrimExpr VisitExpr_(const SubNode* op) final { return AddSubVec(op); } + PrimExpr VisitExpr_(const AddNode* op) final { + return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a + b; }); + } + + PrimExpr VisitExpr_(const SubNode* op) final { + return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a - b; }); + } + PrimExpr VisitExpr_(const MulNode* op) final { PrimExpr a = this->VisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); @@ -423,8 +428,8 @@ class Vectorizer : public StmtExprMutator { return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); } } - template - PrimExpr AddSubVec(const T* op) { + template + PrimExpr AddSubVec(const T* op, FCompute fcompute) { PrimExpr a = this->VisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { @@ -435,12 +440,12 @@ class Vectorizer : public StmtExprMutator { const RampNode* b_ramp = b.as(); const RampNode* a_ramp = a.as(); if (a.dtype().lanes() == 1 && b_ramp) { - return RampNode::make( - arith::Compute(a, b_ramp->base), - arith::Compute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), b_ramp->lanes); + return RampNode::make(fcompute(a, b_ramp->base), + fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), + b_ramp->lanes); } if (b.dtype().lanes() == 1 && a_ramp) { - return RampNode::make(arith::Compute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); + return RampNode::make(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); } } return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));