diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 98edbeaceb26..a8d93bf898c4 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -834,6 +834,14 @@ TVM_DLL Pass InstrumentProfileIntrinsics();
 */
 TVM_DLL Pass DefaultGPUSchedule();
+/*!
+ * \brief This pass analyzes the primfunc and eliminates the branches introduced due to
+ * layout-specific padding. It leverages the buffer assumptions and uses that information
+ * to eliminate the branches.
+ * \note This creates more opportunities to vectorize the code.
+ * \return The Pass.
+ */
+TVM_DLL Pass UseAssumeToReduceBranches();
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index c2022b918643..d8531401d49d 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -1199,3 +1199,16 @@ def DefaultGPUSchedule():
         ret: tvm.transform.Pass
     """
     return _ffi_api.DefaultGPUSchedule()  # type: ignore
+
+
+def UseAssumeToReduceBranches():
+    """This pass attempts to eliminate layout-specific pad branches by overcomputing the values
+    for the padded region. Eliminating the branch helps to vectorize the code
+    and improves the performance of element-wise ops.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.UseAssumeToReduceBranches()  # type: ignore
diff --git a/src/tir/transforms/using_assume_to_reduce_branches.cc b/src/tir/transforms/using_assume_to_reduce_branches.cc
new file mode 100644
index 000000000000..2e45bb0ff8fb
--- /dev/null
+++ b/src/tir/transforms/using_assume_to_reduce_branches.cc
@@ -0,0 +1,394 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file using_assume_to_reduce_branches.cc
+ *
+ * \brief Attempt to remove conditional branch statements by introducing
+ * extra computations that do not impact the final results. Mainly
+ * oriented towards layout-specific padding related branches.
+ *
+ * \note
+ * 1. This pass works if the buffer assumption variable is in the branch statement.
+ *    If the buffer assumption is not present in the branch statement and there are
+ *    intermediate buffers, then inline the code first.
+ * 2. The assumptions leveraged here should be of the form
+ *    T.assume(condition_on_indices or buffer_equals_to_some_value).
+ * 3. Some parts of the code are reused from control_flow_graph.cc, which also
+ *    handles eliminating branches in particular scenarios.
+ * 4. This pass currently works for op_pattern kElemWise and kBroadcast.
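+ * 5. Usage sketch (illustrative; `mod` is a placeholder IRModule): the pass is expected to run
+ *    after AnnotateTIROpPattern has set the "op_pattern" attribute on the PrimFuncs, e.g. from
+ *    Python: mod = tvm.tir.transform.UseAssumeToReduceBranches()(mod)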
+ */
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+#include "../../arith/constraint_extract.h"
+#include "../../arith/ir_mutator_with_analyzer.h"
+#include "../../arith/unwrap_vector_expr.h"
+#include "simplify.h"
+#include "tvm/ir/expr.h"
+namespace tvm {
+namespace tir {
+
+using namespace arith;
+
+class AssumeChecker : public StmtExprVisitor {
+  /* This class checks whether the primfunc has an assume statement.
+  Only if it does is the ParseAssumeAndOvercompute mutator run; this keeps the pass fast. */
+ public:
+  bool has_assume = false;
+
+  void VisitStmt(const Stmt& stmt) final {
+    if (has_assume) {
+      return;
+    }
+    StmtVisitor::VisitStmt(stmt);
+  }
+  void VisitExpr_(const CallNode* op) override {
+    if (op->op.same_as(builtin::assume())) {
+      has_assume = true;
+    }
+  }
+};
+
+class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer {
+  /* This class analyzes the complete primfunc.
+  It parses the buffer assumptions and eliminates the redundant branches
+  introduced due to layout-specific padding by leveraging the buffer assumptions.
+  Eliminating the branch opens up more opportunities to vectorize the code
+  and improves performance.
+
+  Example:
+  -------------
+  Prim Func Before :
+  for (...)
+    T.assume( assume_condition or A[i] == 0 )
+  for (...)
+    out = T.if_then_else(if_then_else_condition, 0, function(A))
+    # here function(A) is some function on Var A
+
+  Prim Func After :
+  for (...)
+    T.assume( assume_condition or A[i] == 0 )
+  for (...)
+    out = function(A)  # here function(A) is some function on the Var A
+  --------------
+  # High-level implementation details :
+  1. The pass parses the assume statement and stores the relevant information.
+  2. The pass tries to evaluate the then_clause and else_clause in the then_condition_context
+  and the else_condition_context.
+  It checks whether the context of the assume statement (the condition indices and the
+  assume_condition) is the same as the context of the if_then_else statement (the condition
+  indices and the if_then_else condition). If the contexts match and the expression inside the
+  if_then_else statement is a function of the assumed buffer (e.g. A in the above example),
+  then the pass substitutes the value from the buffer assumption and simplifies the expression.
+  3. The pass then checks whether the then_clause and else_clause evaluate to the same value.
+  If they do, it returns the else_clause when in the then_condition_context (the then_clause is
+  the value taken in this context, so if the else_clause simplifies to the same value we can
+  directly replace the whole expression with the else_clause); similarly, it returns the
+  then_clause when in the else_condition_context.
+  For example, given T.assume(i < 14 or A[i] == 0) (and the same for B) and
+  out = T.if_then_else(i >= 14, 0, A[i] + B[i]), both loads become 0 in the i >= 14 context,
+  the two clauses agree there, and the store is rewritten to out = A[i] + B[i].
+  This class handles all these scenarios. */
+
+ public:
+  using Parent = IRMutatorWithAnalyzer;
+  explicit ParseAssumeAndOvercompute(Analyzer* analyzer) : Parent(analyzer) {}
+
+ private:
+  using Parent::VisitExpr_;
+  using Parent::VisitStmt;
+  using Parent::VisitStmt_;
+
+  // This struct stores all the relevant data related to an assume statement
+  struct assume_struct {          // Consider the example : T.assume(i < 14 or A[i] == 0)
+    PrimExpr buffer_context;      // The context of the assume statement (the bound on the axis)
+    PrimExpr buffer_predicate;    // The condition inside the assume statement (i < 14) excluding
+                                  // the bufferload expression (A[i] == 0)
+    tir::BufferLoad buffer_load;  // The buffer load, e.g. A[i] in A[i] == 0
+    PrimExpr buffer_value;        // The assumed value of the buffer, e.g. 0 in A[i] == 0
+    Array<PrimExpr> buffer_indices;  // The indices of the buffer, e.g. i
+  };
+  // List of conditions in the current scope
+  std::vector<PrimExpr> conditions_;
+
+  // Stores all the buffer assumption data, keyed by buffer
+  std::map<tir::Buffer, assume_struct> map_buffer_assumption;
+  tir::Buffer current_bufferstorenode_name;
+
+  struct InternalConstraintContext {
+    /* This struct appends the constraint passed to it to the conditions_ list.
+    It keeps track of the bounds of the variables along with any conditions on the variables. */
+    InternalConstraintContext(ParseAssumeAndOvercompute* self, PrimExpr constraint)
+        : self(self), analyzer_context(self->analyzer_, constraint) {
+      old_num_constraints = self->conditions_.size();
+
+      auto side_effect = tir::SideEffect(constraint);
+      if (side_effect <= tir::CallEffectKind::kPure) {
+        self->conditions_.push_back(constraint);
+      } else if (side_effect <= tir::CallEffectKind::kReadState) {
+        assume = constraint;
+      }
+
+      new_num_constraints = self->conditions_.size();
+    }
+
+    ~InternalConstraintContext() {
+      ICHECK_EQ(self->conditions_.size(), new_num_constraints)
+          << "Internal error: Each condition should only be popped once.";
+      self->conditions_.erase(self->conditions_.begin() + old_num_constraints,
+                              self->conditions_.end());
+    }
+
+    ParseAssumeAndOvercompute* self{nullptr};
+    With<ConstraintContext> analyzer_context;
+    size_t old_num_constraints{0};
+    size_t new_num_constraints{0};
+    Optional<PrimExpr> assume{NullOpt};
+
+    // Disable default-generated copy/move assignment and constructors
+    InternalConstraintContext(const InternalConstraintContext&) = delete;
+    InternalConstraintContext& operator=(const InternalConstraintContext&) = delete;
+    InternalConstraintContext(InternalConstraintContext&&) = delete;
+    InternalConstraintContext& operator=(InternalConstraintContext&&) = delete;
+  };
+
+  PrimExpr CurrentScopePredicate() const {
+    /* Combine all the constraints in the current scope */
+    PrimExpr predicate = Bool(true);
+    for (const auto& condition : conditions_) {
+      predicate = predicate && condition;
+    }
+    return predicate;
+  }
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    /* Create and delete the scope with bind.
+    Add the minimum and maximum bounds of the loop variable to the conditions_ list using
+    InternalConstraintContext */
+    analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
+    InternalConstraintContext ctx1(this, op->loop_var >= op->min);
+    InternalConstraintContext ctx2(this, op->loop_var < op->min + op->extent);
+    return Parent::VisitStmt_(op);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) override {
+    if (map_buffer_assumption.find(op->buffer) != map_buffer_assumption.end()) {
+      PrimExpr buf_value;
+      /* If the current context of the buffer load is the same as the context of the buffer
+      assumption, then return the buffer value recorded in the assumption.
+      This eventually replaces the bufferload value in the complete expression. */
+
+      auto buffer_assumption = map_buffer_assumption[op->buffer];
+      PrimExpr current_predicate_and_context = CurrentScopePredicate();
+      PrimExpr buffer_predicate_and_context =
+          buffer_assumption.buffer_context && buffer_assumption.buffer_predicate;
+      bool current_context_and_buffer_constraint_is_same = StructuralEqual()(
+          current_predicate_and_context, buffer_predicate_and_context, /*map_free_vars=*/true);
+
+      if (current_context_and_buffer_constraint_is_same) {
+        buf_value = buffer_assumption.buffer_value;
+        return buf_value;
+      }
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op));
+
+    // Eliminate the builtin if_then_else statement
+    if (auto* call = op->value.as<CallNode>()) {
+      if (call->op.same_as(builtin::if_then_else())) {
+        PrimExpr cond = call->args[0];
+        PrimExpr then_clause = call->args[1];
+        PrimExpr else_clause = call->args[2];
+
+        PrimExpr then_clause_in_then_context;
+        PrimExpr else_clause_in_then_context;
+        PrimExpr then_clause_in_else_context;
+        PrimExpr else_clause_in_else_context;
+        {
+          // Simplify the expressions in the "then" context
+          InternalConstraintContext then_ctx(this, cond);
+          // This calls this class's own visitors, e.g. the BufferLoad visitor above
+          then_clause_in_then_context = (*this)(then_clause);
+          then_clause_in_then_context = analyzer_->Simplify(then_clause_in_then_context);
+
+          else_clause_in_then_context = (*this)(else_clause);
+          else_clause_in_then_context = analyzer_->Simplify(else_clause_in_then_context);
+        }
+        {
+          // Simplify the expressions in the "else" context
+          InternalConstraintContext else_ctx(this, !cond);
+          // This calls this class's own visitors, e.g. the BufferLoad visitor above
+          then_clause_in_else_context = (*this)(then_clause);
+          then_clause_in_else_context = analyzer_->Simplify(then_clause_in_else_context);
+
+          else_clause_in_else_context = (*this)(else_clause);
+          else_clause_in_else_context = analyzer_->Simplify(else_clause_in_else_context);
+        }
+
+        auto n = this->CopyOnWrite(op);
+        if (StructuralEqual()(then_clause_in_then_context, else_clause_in_then_context)) {
+          n->value = analyzer_->Simplify(else_clause);
+          return Stmt(n);
+        } else if (StructuralEqual()(then_clause_in_else_context, else_clause_in_else_context)) {
+          n->value = analyzer_->Simplify(then_clause);
+          return Stmt(n);
+        } else {
+          return Parent::VisitStmt_(op);
+        }
+      }
+    }
+    return Parent::VisitStmt_(op);
+  }
+
+  PrimExpr VisitExpr_(const CallNode* op) override {
+    if (op->op.same_as(builtin::assume())) {
+      Assume(op->args[0]);
+    }
+    return Parent::VisitExpr_(op);
+  }
+
+  void Assume(PrimExpr assumption) {
+    for (const auto& expr : arith::ExtractConstraints(assumption, false)) {
+      AssumeConstraintComponent(expr);
+    }
+  }
+
+  void
AssumeConstraintComponent(PrimExpr assumption) {
+    PrimExpr additional_predicate = Bool(true);
+    assume_struct buf_data;
+
+    std::vector<PrimExpr> buffer_exprs;
+    for (const auto& expr : arith::ExtractComponents(assumption)) {
+      auto side_effect = tir::SideEffect(expr);
+      if (side_effect <= tir::CallEffectKind::kPure) {
+        // Pulling out portions of the assumption that do not depend
+        // on a buffer value allows the following two forms to be
+        // treated identically.
+        //
+        // Option 1: if i < 3: T.assume(buf[i] == value)
+        // Option 2: T.assume(i>=3 or buf[i] == value)
+        additional_predicate = additional_predicate && logical_not(expr);
+      } else if (side_effect == tir::CallEffectKind::kReadState) {
+        buffer_exprs.push_back(expr);
+      } else {
+        LOG(FATAL) << "Assumption must be pure or read-only, but contained expression " << expr
+                   << " with side-effect \'" << side_effect << "\'";
+      }
+    }
+
+    additional_predicate = analyzer_->Simplify(std::move(additional_predicate));
+    CHECK_EQ(buffer_exprs.size(), 1) << "T.assume must contain only a single buffer expression";
+
+    auto* as_equal_node = buffer_exprs[0].as<tir::EQNode>();
+    CHECK(as_equal_node) << "T.assume buffer constraint must be of the form 'buffer[indices] == "
+                            "value', but received "
+                         << assumption;
+    if (!as_equal_node) {
+      // This assumption is an inequality on a data-dependent
+      // conditional.  Not an error for this to occur, but also not
+      // something that is currently supported.
+      return;
+    }
+
+    // Parse the statement and store the desired values
+    // Ex: A[i]==0, load = A[i], value = 0
+    tir::BufferLoad load;
+    PrimExpr value;
+    if (auto opt = as_equal_node->a.as<tir::BufferLoad>()) {
+      load = opt.value();
+      value = as_equal_node->b;
+    } else if (auto opt = as_equal_node->b.as<tir::BufferLoad>()) {
+      load = opt.value();
+      value = as_equal_node->a;
+    } else {
+      LOG(FATAL) << "T.assume buffer constraint must be of the form 'buffer[indices] == value'";
+    }
+
+    // Populate the assume statement predicate, buffer, value,
+    // and the context of the assume statement
+    buf_data.buffer_context = CurrentScopePredicate();
+    buf_data.buffer_predicate = additional_predicate;
+    buf_data.buffer_load = load;
+    buf_data.buffer_value = value;
+    buf_data.buffer_indices = load->indices;
+    for (size_t i = 0; i < load->indices.size(); i++) {
+      buf_data.buffer_indices.push_back(analyzer_->Simplify(load->indices[i]));
+    }
+    map_buffer_assumption[buf_data.buffer_load->buffer] = buf_data;
+
+    auto has_side_effect = tir::SideEffect(value) > tir::CallEffectKind::kPure;
+    CHECK(!has_side_effect) << "Buffer value in constraint must be pure expression, but was "
+                            << value;
+    if (has_side_effect) {
+      return;
+    }
+  }
+};
+
+namespace transform {
+
+Pass UseAssumeToReduceBranches() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    arith::Analyzer analyzer;
+
+    // The pass runs and eliminates the pad branch with overcompute only if
+    // the primfunc has op_pattern defined and is an elementwise op.
+    // The AnnotateTIROpPattern pass sets op_pattern in the attributes of the primfunc.
+    if (n->attrs.GetAttr<Integer>("op_pattern").defined()) {
+      Optional<Integer> opt_pattern = f->GetAttr<Integer>("op_pattern");
+      if (opt_pattern.defined()) {
+        relay::OpPatternKind pattern;
+        pattern = static_cast<relay::OpPatternKind>(Downcast<Integer>(opt_pattern)->value);
+
+        if (pattern == relay::OpPatternKind::kElemWise ||
+            pattern == relay::OpPatternKind::kBroadcast) {
+          // If the primfunc contains an assume statement, then run the mutator pass.
+          AssumeChecker assume_checker;
+          assume_checker(std::move(n->body));
+
+          if (assume_checker.has_assume) {
+            // Leverage the assume and eliminate the branch
+            ParseAssumeAndOvercompute func_analyzer_mutator(&analyzer);
+            n->body = func_analyzer_mutator(std::move(n->body));
+          }
+        }
+      }
+    }
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.UseAssumeToReduceBranches", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.UseAssumeToReduceBranches")
+    .set_body_typed(UseAssumeToReduceBranches);
+
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py b/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py
new file mode 100644
index 000000000000..b8ff2b6c79b2
--- /dev/null
+++ b/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py
@@ -0,0 +1,648 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring, unused-variable
+
+# The tests attempt to eliminate the redundant pad branch and overcompute the values for
+# element-wise ops. This helps to expose more opportunities to vectorize the code.
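+#
+# Minimal usage sketch (illustrative; `AddBefore`/`AddExpected` are the modules defined below):
+#
+#   after = tvm.tir.transform.UseAssumeToReduceBranches()(AddBefore)
+#   tvm.ir.assert_structural_equal(after["add"], AddExpected["add"], map_free_vars=True)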
+ +import tvm +import tvm.testing + +import tvm.script +from tvm.script import tir as T, relax as R + + +@tvm.script.ir_module +class AddBefore: + @T.prim_func(private=True) + def add( + a: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + b: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + compute: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + ): + T.func_attr( + { + "op_attrs": {"lhs_axis": 0, "op_name": "qnn.add", "rhs_axis": 0}, + "op_pattern": 0, + "operator_name": "add", + "tir.noalias": T.bool(True), + } + ) + # with T.block("root"): + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_A_assumptions"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + not ( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5 + ) + or a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_B_assumptions"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + not ( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5 + ) + or b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("compute"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads( + a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + T.writes(compute[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + compute[ + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 + ] = T.if_then_else( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5, + T.uint8(0), + a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + + b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + + @R.function + def main( + a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): + out = R.call_tir( + AddBefore.add, + (a, b), + out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + ) + return out + + +@tvm.script.ir_module +class AddExpected: + @T.prim_func(private=True) + def add( + a: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + b: T.Buffer( + (T.int64(1), T.int64(4), 
T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + compute: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + ): + T.func_attr( + { + "op_attrs": {"lhs_axis": 0, "op_name": "qnn.add", "rhs_axis": 0}, + "op_pattern": 0, + "operator_name": "add", + "tir.noalias": T.bool(True), + } + ) + # with T.block("root"): + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_A_assumptions"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) + and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) + or a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_B_assumptions"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) + and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) + or b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5_0 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(2) + ): + for axis5_1_axis6_fused in T.vectorized(T.int64(128)): + with T.block("compute"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap( + "SSSS", [axis1, axis2, axis3, axis4] + ) + v_axis5 = T.axis.spatial( + T.int64(8), axis5_0 * T.int64(4) + axis5_1_axis6_fused // T.int64(32) + ) + v_axis6 = T.axis.spatial(T.int64(32), axis5_1_axis6_fused % T.int64(32)) + T.reads( + a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + T.writes( + compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + ) + compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] = ( + a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + + b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + ) + + @R.function + def main( + a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): + out = R.call_tir( + AddExpected.add, + (a, b), + out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + ) + return out + + +@tvm.script.ir_module +class SubBefore: + @T.prim_func(private=True) + def sub( + a: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + b: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + compute: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + ): 
+ T.func_attr( + { + "op_attrs": {"lhs_axis": 0, "op_name": "qnn.subtract", "rhs_axis": 0}, + "op_pattern": 0, + "operator_name": "sub", + "tir.noalias": T.bool(True), + } + ) + # with T.block("root"): + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_A_assumptions"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + not ( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5 + ) + or a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_B_assumptions"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + not ( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5 + ) + or b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("compute"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads( + a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + T.writes(compute[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + compute[ + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 + ] = T.if_then_else( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5, + T.uint8(0), + a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + - b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + + @R.function + def main( + a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): + out = R.call_tir( + SubBefore.sub, + (a, b), + out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + ) + return out + + +@tvm.script.ir_module +class SubExpected: + @T.prim_func(private=True) + def sub( + a: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + b: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + compute: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + ): + T.func_attr( + { + "op_attrs": {"lhs_axis": 0, "op_name": "qnn.subtract", "rhs_axis": 0}, + "op_pattern": 0, + "operator_name": "sub", + "tir.noalias": T.bool(True), + } + ) + # with T.block("root"): + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), 
T.int64(32) + ): + with T.block("buffer_A_assumptions"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) + and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) + or a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_B_assumptions"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) + and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) + or b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5_0 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(2) + ): + for axis5_1_axis6_fused in T.vectorized(T.int64(128)): + with T.block("compute"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap( + "SSSS", [axis1, axis2, axis3, axis4] + ) + v_axis5 = T.axis.spatial( + T.int64(8), axis5_0 * T.int64(4) + axis5_1_axis6_fused // T.int64(32) + ) + v_axis6 = T.axis.spatial(T.int64(32), axis5_1_axis6_fused % T.int64(32)) + T.reads( + a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + T.writes( + compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + ) + compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] = ( + a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + - b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + ) + + @R.function + def main( + a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): + out = R.call_tir( + SubExpected.sub, + (a, b), + out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + ) + return out + + +@tvm.script.ir_module +class MulBefore: + @T.prim_func(private=True) + def mul( + a: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + b: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + compute: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + ): + T.func_attr( + { + "op_attrs": {"lhs_axis": 0, "op_name": "qnn.mul", "rhs_axis": 0}, + "op_pattern": 0, + "operator_name": "mul", + "tir.noalias": T.bool(True), + } + ) + # with T.block("root"): + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_A_assumptions"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + 
T.reads(a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + not ( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5 + ) + or a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_B_assumptions"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + not ( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5 + ) + or b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("compute"): + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads( + a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + T.writes(compute[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + compute[ + v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 + ] = T.if_then_else( + v_axis1 == T.int64(3) + and T.int64(4) <= v_axis4 + or v_axis2 == T.int64(3) + and T.int64(4) <= v_axis5, + T.uint8(0), + a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + * b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + + @R.function + def main( + a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): + out = R.call_tir( + MulBefore.mul, + (a, b), + out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + ) + return out + + +@tvm.script.ir_module +class MulExpected: + @T.prim_func(private=True) + def mul( + a: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + b: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + compute: T.Buffer( + (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32)), + "uint8", + ), + ): + T.func_attr( + { + "op_attrs": {"lhs_axis": 0, "op_name": "qnn.mul", "rhs_axis": 0}, + "op_pattern": 0, + "operator_name": "mul", + "tir.noalias": T.bool(True), + } + ) + # with T.block("root"): + for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_A_assumptions"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) + and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) + or a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, 
axis2, axis3, axis4, axis5, axis6 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(8), T.int64(32) + ): + with T.block("buffer_B_assumptions"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = T.axis.remap( + "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6] + ) + T.reads(b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6]) + T.writes() + T.assume( + (v_axis1 < T.int64(3) or v_axis4 < T.int64(4)) + and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4)) + or b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + == T.uint8(0) + ) + + for axis0, axis1, axis2, axis3, axis4, axis5_0 in T.grid( + T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), T.int64(2) + ): + for axis5_1_axis6_fused in T.vectorized(T.int64(128)): + with T.block("compute"): + v_axis0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap( + "SSSS", [axis1, axis2, axis3, axis4] + ) + v_axis5 = T.axis.spatial( + T.int64(8), axis5_0 * T.int64(4) + axis5_1_axis6_fused // T.int64(32) + ) + v_axis6 = T.axis.spatial(T.int64(32), axis5_1_axis6_fused % T.int64(32)) + T.reads( + a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6], + ) + T.writes( + compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + ) + compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] = ( + a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + * b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6] + ) + + @R.function + def main( + a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"), + ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"): + out = R.call_tir( + MulExpected.mul, + (a, b), + out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"), + ) + return out + + +def test_add_primfunc_overcompute(): + add_after = tvm.tir.transform.UseAssumeToReduceBranches()(AddBefore) + tvm.ir.structural_equal(add_after["add"], AddExpected["add"], map_free_vars=True) + + +def test_sub_primfunc_overcompute(): + sub_after = tvm.tir.transform.UseAssumeToReduceBranches()(SubBefore) + tvm.ir.structural_equal(sub_after["sub"], SubExpected["sub"], map_free_vars=True) + + +def test_mul_primfunc_overcompute(): + mul_after = tvm.tir.transform.UseAssumeToReduceBranches()(MulBefore) + tvm.ir.structural_equal(mul_after["mul"], MulExpected["mul"], map_free_vars=True) + + +if __name__ == "__main__": + tvm.testing.main()
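+
+# A possible shared helper for the checks above (illustrative sketch only; not used by the tests):
+#
+#   def _verify(before_mod, expected_mod, func_name):
+#       after_mod = tvm.tir.transform.UseAssumeToReduceBranches()(before_mod)
+#       tvm.ir.assert_structural_equal(
+#           after_mod[func_name], expected_mod[func_name], map_free_vars=True
+#       )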