From 8937617c7671c734cc649d58882ed54c86b938e2 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 24 Feb 2020 16:10:02 +0800 Subject: [PATCH 01/25] Support large tensors --- include/tvm/tir/ir_pass.h | 2 + python/tvm/driver/build_module.py | 17 +++ src/target/llvm/codegen_llvm.cc | 3 +- src/tir/ir/buffer.cc | 2 +- src/tir/ir/expr.cc | 3 +- src/tir/pass/ffi_api.cc | 1 + src/tir/pass/loop_partition.cc | 2 +- src/tir/pass/rewrite_datatype.cc | 245 ++++++++++++++++++++++++++++++ src/tir/pass/unroll_loop.cc | 4 +- 9 files changed, 274 insertions(+), 5 deletions(-) create mode 100644 src/tir/pass/rewrite_datatype.cc diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h index 6e9a631fab4d..51ed40c1011f 100644 --- a/include/tvm/tir/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -386,6 +386,8 @@ Stmt DecorateDeviceScope(Stmt stmt); */ Stmt HoistIfThenElse(Stmt stmt); +Stmt DataTypeRewrite(Stmt stmt); + /*! * \brief Make an user callable API LoweredFunc. * diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 67eb22414abd..65bbdeef2e94 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -149,6 +149,7 @@ def lower(sch, # Phase 0 if isinstance(sch, schedule.Schedule): stmt = form_body(sch) + print("{:>32} = \n{}".format("form_body", stmt)) for f in lower_phase0: stmt = f(stmt) @@ -158,40 +159,56 @@ def lower(sch, # Phase 1 stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds) + print("{:>32} = \n{}".format("RewriteForTensorCore", stmt)) stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers) + print("{:>32} = \n{}".format("StorageFlatten", stmt)) + stmt = ir_pass.DataTypeRewrite(stmt) + print("{:>32} = \n{}".format("DataTypeRewrite", stmt)) stmt = ir_pass.CanonicalSimplify(stmt) + print("{:>32} = \n{}".format("CanonicalSimplify", stmt)) for f in lower_phase1: stmt = f(stmt) # Phase 2 if not simple_mode: stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop) + print("{:>32} = \n{}".format("LoopPartition", stmt)) if cfg.disable_vectorize: stmt = ir_pass.SkipVectorize(stmt) + print("{:>32} = \n{}".format("SkipVectorize", stmt)) else: stmt = ir_pass.VectorizeLoop(stmt) + print("{:>32} = \n{}".format("VectorizeLoop", stmt)) stmt = ir_pass.InjectVirtualThread(stmt) + print("{:>32} = \n{}".format("InjectVirtualThread", stmt)) stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop) + print("{:>32} = \n{}".format("InjectDoubleBuffer", stmt)) stmt = ir_pass.StorageRewrite(stmt) + print("{:>32} = \n{}".format("StorageRewrite", stmt)) stmt = ir_pass.UnrollLoop( stmt, cfg.auto_unroll_max_step, cfg.auto_unroll_max_depth, cfg.auto_unroll_max_extent, cfg.unroll_explicit) + print("{:>32} = \n{}".format("UnrollLoop", stmt)) for f in lower_phase2: stmt = f(stmt) # Phase 3 stmt = ir_pass.Simplify(stmt) + print("{:>32} = \n{}".format("Simplify", stmt)) stmt = ir_pass.RemoveNoOp(stmt) + print("{:>32} = \n{}".format("RemoveNoOp", stmt)) if not cfg.disable_select_rewriting: stmt = ir_pass.RewriteUnsafeSelect(stmt) + print("{:>32} = \n{}".format("RewriteUnsafeSelect", stmt)) for f in lower_phase3: stmt = f(stmt) # Instrument BoundCheckers if cfg.instrument_bound_checkers: stmt = ir_pass.InstrumentBoundCheckers(stmt) + print("{:>32} = \n{}".format("InstrumentBoundCheckers", stmt)) if simple_mode: return stmt diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 68d004cc481b..4dfcdf97ff13 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1121,7 +1121,8 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) { CHECK(op->for_type == ForType::Serial); } CreateSerialFor(MakeValue(op->min), MakeValue(op->extent), - ConstInt32(1), op->loop_var, op->body); + llvm::ConstantInt::getSigned(LLVMType(op->extent.dtype()), 1), + op->loop_var, op->body); } diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 19e32d6681ae..7f45e01b933c 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -448,7 +448,7 @@ Buffer BufferNode::make(Var data, n->buffer_type = buffer_type; if (n->buffer_type == kAutoBroadcast && n->shape.size() > 0 && n->strides.empty()) { for (size_t i = 0; i < n->shape.size(); ++i) { - n->strides.push_back(Var("stride")); + n->strides.push_back(Var("stride", n->shape[i].dtype())); } } return Buffer(n); diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index bee025687173..67449c059900 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -454,7 +454,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) auto* op = static_cast(node.get()); // omit the type // stream << op->name << "." << op->type; - p->stream << op->name_hint; + // p->stream << op->name_hint; + p->stream << op->name_hint << "." << op->dtype; }) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc index 46d0f67c6d51..2c711ae39b47 100644 --- a/src/tir/pass/ffi_api.cc +++ b/src/tir/pass/ffi_api.cc @@ -165,5 +165,6 @@ REGISTER_PASS(InstrumentBoundCheckers); REGISTER_PASS(VerifyCompactBuffer); REGISTER_PASS(HoistIfThenElse); REGISTER_PASS(InferFragment) +REGISTER_PASS(DataTypeRewrite); } // namespace tir } // namespace tvm diff --git a/src/tir/pass/loop_partition.cc b/src/tir/pass/loop_partition.cc index d1fa46e38860..e9157e796e38 100644 --- a/src/tir/pass/loop_partition.cc +++ b/src/tir/pass/loop_partition.cc @@ -587,7 +587,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt b // If the loop extent is 1, do not create the loop anymore return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}}); } else { - return ForNode::make(for_node->loop_var, 0, extent, + return ForNode::make(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, for_node->for_type, for_node->device_api, body); } } diff --git a/src/tir/pass/rewrite_datatype.cc b/src/tir/pass/rewrite_datatype.cc new file mode 100644 index 000000000000..c44b04ca2fa8 --- /dev/null +++ b/src/tir/pass/rewrite_datatype.cc @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rewrite_datatype.cc + */ + +#include +#include +#include "../../arith/ir_mutator_with_analyzer.h" +#include "../../arith/ir_visitor_with_analyzer.h" + +namespace tvm { +namespace tir { + +using arith::Analyzer; +using arith::IRMutatorWithAnalyzer; +using arith::ConstIntBound; + +class DataTypeVisitor final : public StmtExprVisitor { + public: + void VisitExpr(const PrimExpr& e) { + if (e.dtype().is_int()) { + int bits = 64; + if (analyzer_.CanProve(e <= max_value(DataType::Int(32)) && + e >= min_value(DataType::Int(32)))) { + bits = 32; + } + int tmp = bits_; + bits_ = bits > bits_ ? bits : bits_; + StmtExprVisitor::VisitExpr(e); + bits_ = tmp; + } else { + StmtExprVisitor::VisitExpr(e); + } + } + + void VisitStmt_(const ForNode* op) { + analyzer_.Bind(op->loop_var, + Range::make_by_min_extent(op->min, op->extent)); + vset_.insert(op->loop_var.as()); + return StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const AttrStmtNode* op) { + if (op->attr_key == attr::thread_extent || + op->attr_key == attr::virtual_thread) { + IterVar iv = Downcast(op->node); + CHECK_NE(iv->thread_tag.length(), 0U); + analyzer_.Bind(iv->var, + Range::make_by_min_extent(0, op->value)); + vset_.insert(iv->var.as()); + StmtExprVisitor::VisitStmt_(op); + } else { + StmtExprVisitor::VisitStmt_(op); + } + } + + void VisitExpr_(const ReduceNode* op) { + // Setup the domain information before simplification. + for (const IterVar& iv : op->axis) { + analyzer_.Bind(iv->var, iv->dom); + vset_.insert(iv->var.as()); + } + // Recursively call simplification when necessary. + StmtExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const VarNode* op) { + if (vset_.find(op) != vset_.end()) { + if (vmap.find(op) == vmap.end()) { + vmap[op] = DataType::Int(bits_); + } else { + vmap[op] = DataType::Int(std::max(vmap[op].bits(), bits_)); + } + } + StmtExprVisitor::VisitExpr_(op); + } + + std::unordered_map vmap; + + protected: + /*! \brief internal analyzer field. */ + arith::Analyzer analyzer_; + + private: + int bits_; + std::unordered_set vset_; +}; + +class DataTypeRewriter : public StmtExprMutator { + public: + Stmt operator()(Stmt s) { + visitor_(s); + return VisitStmt(s); + } + + Stmt VisitStmt_(const ForNode* op) final { + Stmt s = StmtExprMutator::VisitStmt_(op); + op = s.as(); + PrimExpr e = VisitExpr(op->loop_var); + Var var = Downcast(e); + return ForNode::make(var, Cast(op->min, var.dtype()), Cast(op->extent, var.dtype()), + op->for_type, op->device_api, op->body); + } + + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == attr::thread_extent || + op->attr_key == attr::virtual_thread) { + Stmt s = StmtExprMutator::VisitStmt_(op); + op = s.as(); + IterVar iv = Downcast(op->node); + PrimExpr e = VisitExpr(iv->var); + Var var = Downcast(e); + return AttrStmtNode::make( + IterVarNode::make(iv->dom, var, iv->iter_type, iv->thread_tag), + op->attr_key, + Cast(op->value, var.dtype()), + op->body); + } + return StmtExprMutator::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const VarNode* op) final { + if (visitor_.vmap.find(op) != visitor_.vmap.end()) { + if (vmap_.find(op) == vmap_.end()) { + vmap_[op] = Var(op->name_hint, visitor_.vmap[op]); + } + return vmap_[op]; + } + return StmtExprMutator::VisitExpr_(op); + } + + PrimExpr VisitExpr_(const SizeVarNode* op) final { + if (visitor_.vmap.find(op) != visitor_.vmap.end()) { + if (vmap_.find(op) == vmap_.end()) { + vmap_[op] = SizeVar(op->name_hint, visitor_.vmap[op]); + } + return vmap_[op]; + } + return StmtExprMutator::VisitExpr_(op); + } + + PrimExpr VisitExpr_(const AddNode* op) final; + PrimExpr VisitExpr_(const SubNode* op) final; + PrimExpr VisitExpr_(const MulNode* op) final; + PrimExpr VisitExpr_(const DivNode* op) final; + PrimExpr VisitExpr_(const ModNode* op) final; + PrimExpr VisitExpr_(const FloorDivNode* op) final; + PrimExpr VisitExpr_(const FloorModNode* op) final; + PrimExpr VisitExpr_(const MinNode* op) final; + PrimExpr VisitExpr_(const MaxNode* op) final; + PrimExpr VisitExpr_(const EQNode* op) final; + PrimExpr VisitExpr_(const NENode* op) final; + PrimExpr VisitExpr_(const LTNode* op) final; + PrimExpr VisitExpr_(const LENode* op) final; + PrimExpr VisitExpr_(const GTNode* op) final; + PrimExpr VisitExpr_(const GENode* op) final; + PrimExpr VisitExpr_(const CallNode* op) final; + + private: + DataTypeVisitor visitor_; + std::unordered_map vmap_; + PrimExpr Cast(PrimExpr e, DataType dtype) { + if (e.dtype() != dtype) { + return CastNode::make(dtype, e); + } else { + return e; + } + } +}; + +#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ + PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) { \ + PrimExpr a = this->VisitExpr(op->a); \ + PrimExpr b = this->VisitExpr(op->b); \ + if (a.same_as(op->a) && \ + b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + return FUNC(a, b); \ + } \ + } + +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=) + +PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) { + PrimExpr e = StmtExprMutator::VisitExpr_(op); + op = e.as(); + if (op->call_type == CallNode::PureIntrinsic) { + if (op->name == intrinsic::tvm_if_then_else) { + return if_then_else(op->args[0], op->args[1], op->args[2]); + } else if (op->name == CallNode::shift_right) { + return op->args[0] >> op->args[1]; + } else if (op->name == CallNode::shift_left) { + return op->args[0] << op->args[1]; + } else if (op->name == CallNode::bitwise_and) { + return op->args[0] & op->args[1]; + } else if (op->name == CallNode::bitwise_or) { + return op->args[0] | op->args[1]; + } else if (op->name == CallNode::bitwise_xor) { + return op->args[0] ^ op->args[1]; + } else if (op->name == "pow") { + return pow(op->args[0], op->args[1]); + } + } + return e; +} + +Stmt DataTypeRewrite(Stmt stmt) { + return DataTypeRewriter()(stmt); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/pass/unroll_loop.cc b/src/tir/pass/unroll_loop.cc index 3d669d03f1d6..0167dbcec5f2 100644 --- a/src/tir/pass/unroll_loop.cc +++ b/src/tir/pass/unroll_loop.cc @@ -160,7 +160,9 @@ class LoopUnroller : public StmtExprMutator { PrimExpr extent = tir::Simplify(op->extent); const IntImmNode *v1 = extent.as(); int value = -1; - if (v1 != nullptr) { + // integers that do not fit in int32_t are treated as symbolic, + // as it's impossible to unroll such large loops + if (v1 != nullptr && v1->value <= std::numeric_limits::max()) { value = static_cast(v1->value); } return value; From 17751bd1e73c27ffb1233af2240e5641bb4b955d Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Sun, 8 Mar 2020 01:38:25 +0800 Subject: [PATCH 02/25] Clear --- python/tvm/driver/build_module.py | 16 ---------------- src/tir/ir/expr.cc | 3 +-- 2 files changed, 1 insertion(+), 18 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 65bbdeef2e94..220de8eb1282 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -149,7 +149,6 @@ def lower(sch, # Phase 0 if isinstance(sch, schedule.Schedule): stmt = form_body(sch) - print("{:>32} = \n{}".format("form_body", stmt)) for f in lower_phase0: stmt = f(stmt) @@ -159,56 +158,41 @@ def lower(sch, # Phase 1 stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds) - print("{:>32} = \n{}".format("RewriteForTensorCore", stmt)) stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers) - print("{:>32} = \n{}".format("StorageFlatten", stmt)) stmt = ir_pass.DataTypeRewrite(stmt) - print("{:>32} = \n{}".format("DataTypeRewrite", stmt)) stmt = ir_pass.CanonicalSimplify(stmt) - print("{:>32} = \n{}".format("CanonicalSimplify", stmt)) for f in lower_phase1: stmt = f(stmt) # Phase 2 if not simple_mode: stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop) - print("{:>32} = \n{}".format("LoopPartition", stmt)) if cfg.disable_vectorize: stmt = ir_pass.SkipVectorize(stmt) - print("{:>32} = \n{}".format("SkipVectorize", stmt)) else: stmt = ir_pass.VectorizeLoop(stmt) - print("{:>32} = \n{}".format("VectorizeLoop", stmt)) stmt = ir_pass.InjectVirtualThread(stmt) - print("{:>32} = \n{}".format("InjectVirtualThread", stmt)) stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop) - print("{:>32} = \n{}".format("InjectDoubleBuffer", stmt)) stmt = ir_pass.StorageRewrite(stmt) - print("{:>32} = \n{}".format("StorageRewrite", stmt)) stmt = ir_pass.UnrollLoop( stmt, cfg.auto_unroll_max_step, cfg.auto_unroll_max_depth, cfg.auto_unroll_max_extent, cfg.unroll_explicit) - print("{:>32} = \n{}".format("UnrollLoop", stmt)) for f in lower_phase2: stmt = f(stmt) # Phase 3 stmt = ir_pass.Simplify(stmt) - print("{:>32} = \n{}".format("Simplify", stmt)) stmt = ir_pass.RemoveNoOp(stmt) - print("{:>32} = \n{}".format("RemoveNoOp", stmt)) if not cfg.disable_select_rewriting: stmt = ir_pass.RewriteUnsafeSelect(stmt) - print("{:>32} = \n{}".format("RewriteUnsafeSelect", stmt)) for f in lower_phase3: stmt = f(stmt) # Instrument BoundCheckers if cfg.instrument_bound_checkers: stmt = ir_pass.InstrumentBoundCheckers(stmt) - print("{:>32} = \n{}".format("InstrumentBoundCheckers", stmt)) if simple_mode: return stmt diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 67449c059900..bee025687173 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -454,8 +454,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) auto* op = static_cast(node.get()); // omit the type // stream << op->name << "." << op->type; - // p->stream << op->name_hint; - p->stream << op->name_hint << "." << op->dtype; + p->stream << op->name_hint; }) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); From 170976401e0349227aaffe157edbc04ace634acb Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Sun, 8 Mar 2020 01:40:44 +0800 Subject: [PATCH 03/25] Add tests --- src/tir/pass/rewrite_datatype.cc | 3 +- .../unittest/test_pass_datatype_rewrite.py | 101 ++++++++++++++++++ 2 files changed, 103 insertions(+), 1 deletion(-) create mode 100644 tests/python/unittest/test_pass_datatype_rewrite.py diff --git a/src/tir/pass/rewrite_datatype.cc b/src/tir/pass/rewrite_datatype.cc index c44b04ca2fa8..bc630c7265d9 100644 --- a/src/tir/pass/rewrite_datatype.cc +++ b/src/tir/pass/rewrite_datatype.cc @@ -38,7 +38,8 @@ class DataTypeVisitor final : public StmtExprVisitor { void VisitExpr(const PrimExpr& e) { if (e.dtype().is_int()) { int bits = 64; - if (analyzer_.CanProve(e <= max_value(DataType::Int(32)) && + if (e.dtype() == DataType::Int(32) || + analyzer_.CanProve(e <= max_value(DataType::Int(32)) && e >= min_value(DataType::Int(32)))) { bits = 32; } diff --git a/tests/python/unittest/test_pass_datatype_rewrite.py b/tests/python/unittest/test_pass_datatype_rewrite.py new file mode 100644 index 000000000000..abd25af0a2c6 --- /dev/null +++ b/tests/python/unittest/test_pass_datatype_rewrite.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import te + + +def lower(sch, args): + binds = {} + arg_list = [] + for x in args: + if isinstance(x, te.tensor.Tensor): + buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name) + assert x not in binds + binds[x] = buf + arg_list.append(buf) + else: + raise ValueError("args must be Tensor, Buffer or Var") + bounds = te.schedule.InferBound(sch) + stmt = te.schedule.ScheduleOps(sch, bounds) + stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False) + stmt = tvm.tir.ir_pass.DataTypeRewrite(stmt) + return stmt + + +def test_const(): + m, n = 2, 2 + A = te.placeholder((m, n), name='A') + B = te.placeholder((m, n), name='B') + C = te.compute((m, n), lambda *idx: A[idx] + B[idx]) + s = te.create_schedule(C.op) + stmt = lower(s, [A, B, C]) + assert stmt.body.loop_var.dtype == "int32" + assert stmt.body.body.loop_var.dtype == "int32" + m, n = 2**16, 2**16 + A = te.placeholder((m, n), name='A') + B = te.placeholder((m, n), name='B') + C = te.compute((m, n), lambda *idx: A[idx] + B[idx]) + s = te.create_schedule(C.op) + stmt = lower(s, [A, B, C]) + assert stmt.body.loop_var.dtype == "int64" + assert stmt.body.body.loop_var.dtype == "int64" + + +def test_symbolic(): + m, n = te.size_var(name='m'), te.size_var(name='n') + A = te.placeholder((m, n), name='A') + B = te.placeholder((m, n), name='B') + C = te.compute((m, n), lambda *idx: A[idx] + B[idx]) + s = te.create_schedule(C.op) + stmt = lower(s, [A, B, C]) + assert stmt.body.loop_var.dtype == "int64" + assert stmt.body.body.loop_var.dtype == "int64" + + +def test_thread_axis_2dim(): + m, n = 1024, 32 + A = te.placeholder((m, n), name='A') + B = te.placeholder((m, n), name='B') + C = te.compute((m, n), lambda *idx: A[idx] + B[idx]) + s = te.create_schedule(C.op) + s[C].bind(C.op.axis[0], te.thread_axis("blockIdx.x")) + s[C].bind(C.op.axis[1], te.thread_axis("threadIdx.x")) + stmt = lower(s, [A, B, C]) + assert stmt.body.node.var.dtype == "int32" + assert stmt.body.body.node.var.dtype == "int32" + + +def test_thread_axis_3dim(): + m, n, k = 2**12, 2**12, 2**13 + A = te.placeholder((m, n, k), name='A') + B = te.placeholder((m, n, k), name='B') + C = te.compute((m, n, k), lambda *idx: A[idx] + B[idx]) + s = te.create_schedule(C.op) + fused = s[C].fuse(*[axis for axis in C.op.axis]) + xo, xi = s[C].split(fused, factor=32) + s[C].bind(xo, te.thread_axis("blockIdx.x")) + s[C].bind(xi, te.thread_axis("threadIdx.x")) + stmt = lower(s, [A, B, C]) + assert stmt.body.node.var.dtype == "int64" + assert stmt.body.body.node.var.dtype == "int64" + + +if __name__ == "__main__": + test_const() + test_symbolic() + test_thread_axis_2dim() + test_thread_axis_3dim() From 4d420db766cb798c86f827b3212602967490ac92 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Wed, 11 Mar 2020 00:21:08 +0800 Subject: [PATCH 04/25] Change tests --- .../unittest/test_pass_datatype_rewrite.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/python/unittest/test_pass_datatype_rewrite.py b/tests/python/unittest/test_pass_datatype_rewrite.py index abd25af0a2c6..fdd2cd3ca4f3 100644 --- a/tests/python/unittest/test_pass_datatype_rewrite.py +++ b/tests/python/unittest/test_pass_datatype_rewrite.py @@ -51,12 +51,21 @@ def test_const(): C = te.compute((m, n), lambda *idx: A[idx] + B[idx]) s = te.create_schedule(C.op) stmt = lower(s, [A, B, C]) - assert stmt.body.loop_var.dtype == "int64" - assert stmt.body.body.loop_var.dtype == "int64" - + # i32 + i32 is not promoted to i64 even in the case of overflow + assert stmt.body.loop_var.dtype == "int32" + assert stmt.body.body.loop_var.dtype == "int32" + def test_symbolic(): - m, n = te.size_var(name='m'), te.size_var(name='n') + m, n = te.size_var(name='m', dtype='int32'), te.size_var(name='n', dtype='int32') + A = te.placeholder((m, n), name='A') + B = te.placeholder((m, n), name='B') + C = te.compute((m, n), lambda *idx: A[idx] + B[idx]) + s = te.create_schedule(C.op) + stmt = lower(s, [A, B, C]) + assert stmt.body.loop_var.dtype == "int32" + assert stmt.body.body.loop_var.dtype == "int32" + m, n = te.size_var(name='m', dtype='int64'), te.size_var(name='n', dtype='int64') A = te.placeholder((m, n), name='A') B = te.placeholder((m, n), name='B') C = te.compute((m, n), lambda *idx: A[idx] + B[idx]) @@ -90,8 +99,9 @@ def test_thread_axis_3dim(): s[C].bind(xo, te.thread_axis("blockIdx.x")) s[C].bind(xi, te.thread_axis("threadIdx.x")) stmt = lower(s, [A, B, C]) - assert stmt.body.node.var.dtype == "int64" - assert stmt.body.body.node.var.dtype == "int64" + # i32 + i32 is not promoted to i64 even in the case of overflow + assert stmt.body.node.var.dtype == "int32" + assert stmt.body.body.node.var.dtype == "int32" if __name__ == "__main__": From b0dd985d358fc099ae33f711296912fa26083bdd Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Wed, 11 Mar 2020 13:30:10 +0800 Subject: [PATCH 05/25] Add IntImm and Cast --- src/tir/pass/rewrite_datatype.cc | 40 ++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/tir/pass/rewrite_datatype.cc b/src/tir/pass/rewrite_datatype.cc index bc630c7265d9..d4049a4cb9b4 100644 --- a/src/tir/pass/rewrite_datatype.cc +++ b/src/tir/pass/rewrite_datatype.cc @@ -94,6 +94,30 @@ class DataTypeVisitor final : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(op); } + void VisitExpr_(const IntImmNode* op) { + if (op->dtype.is_int()) { + int bits = std::min(op->dtype.bits(), bits_); + if (vmap.find(op) == vmap.end()) { + vmap[op] = DataType::Int(bits); + } else { + vmap[op] = DataType::Int(std::max(vmap[op].bits(), bits)); + } + } + StmtExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const CastNode* op) { + if (op->dtype.is_int()) { + int bits = std::min(op->dtype.bits(), bits_); + if (vmap.find(op) == vmap.end()) { + vmap[op] = DataType::Int(bits); + } else { + vmap[op] = DataType::Int(std::max(vmap[op].bits(), bits)); + } + } + StmtExprVisitor::VisitExpr_(op); + } + std::unordered_map vmap; protected: @@ -158,6 +182,22 @@ class DataTypeRewriter : public StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } + PrimExpr VisitExpr_(const IntImmNode* op) final { + if (visitor_.vmap.find(op) != visitor_.vmap.end()) { + return IntImm(visitor_.vmap[op], op->value); + } + return StmtExprMutator::VisitExpr_(op); + } + + PrimExpr VisitExpr_(const CastNode* op) final { + if (visitor_.vmap.find(op) != visitor_.vmap.end()) { + PrimExpr e = StmtExprMutator::VisitExpr_(op); + const CastNode* new_op = e.as(); + return CastNode::make(visitor_.vmap[op], new_op->value); + } + return StmtExprMutator::VisitExpr_(op); + } + PrimExpr VisitExpr_(const AddNode* op) final; PrimExpr VisitExpr_(const SubNode* op) final; PrimExpr VisitExpr_(const MulNode* op) final; From 10ac54f1358c49186e8d32fd308d7302b5dbfd0b Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Wed, 11 Mar 2020 17:19:10 +0800 Subject: [PATCH 06/25] Fix multi-lanes dtype --- src/tir/pass/rewrite_datatype.cc | 17 ++++++++++------- .../unittest/test_pass_datatype_rewrite.py | 13 +++++++++++++ 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/src/tir/pass/rewrite_datatype.cc b/src/tir/pass/rewrite_datatype.cc index d4049a4cb9b4..0e1c116eabb3 100644 --- a/src/tir/pass/rewrite_datatype.cc +++ b/src/tir/pass/rewrite_datatype.cc @@ -33,12 +33,14 @@ using arith::Analyzer; using arith::IRMutatorWithAnalyzer; using arith::ConstIntBound; +class DataTypeRewriter; + class DataTypeVisitor final : public StmtExprVisitor { public: void VisitExpr(const PrimExpr& e) { if (e.dtype().is_int()) { int bits = 64; - if (e.dtype() == DataType::Int(32) || + if (e.dtype().bits() <= 32 || analyzer_.CanProve(e <= max_value(DataType::Int(32)) && e >= min_value(DataType::Int(32)))) { bits = 32; @@ -86,9 +88,9 @@ class DataTypeVisitor final : public StmtExprVisitor { void VisitExpr_(const VarNode* op) { if (vset_.find(op) != vset_.end()) { if (vmap.find(op) == vmap.end()) { - vmap[op] = DataType::Int(bits_); + vmap[op] = op->dtype.with_bits(bits_); } else { - vmap[op] = DataType::Int(std::max(vmap[op].bits(), bits_)); + vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits_)); } } StmtExprVisitor::VisitExpr_(op); @@ -98,9 +100,9 @@ class DataTypeVisitor final : public StmtExprVisitor { if (op->dtype.is_int()) { int bits = std::min(op->dtype.bits(), bits_); if (vmap.find(op) == vmap.end()) { - vmap[op] = DataType::Int(bits); + vmap[op] = op->dtype.with_bits(bits); } else { - vmap[op] = DataType::Int(std::max(vmap[op].bits(), bits)); + vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits)); } } StmtExprVisitor::VisitExpr_(op); @@ -110,9 +112,9 @@ class DataTypeVisitor final : public StmtExprVisitor { if (op->dtype.is_int()) { int bits = std::min(op->dtype.bits(), bits_); if (vmap.find(op) == vmap.end()) { - vmap[op] = DataType::Int(bits); + vmap[op] = op->dtype.with_bits(bits); } else { - vmap[op] = DataType::Int(std::max(vmap[op].bits(), bits)); + vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits)); } } StmtExprVisitor::VisitExpr_(op); @@ -127,6 +129,7 @@ class DataTypeVisitor final : public StmtExprVisitor { private: int bits_; std::unordered_set vset_; + friend class DataTypeRewriter; }; class DataTypeRewriter : public StmtExprMutator { diff --git a/tests/python/unittest/test_pass_datatype_rewrite.py b/tests/python/unittest/test_pass_datatype_rewrite.py index fdd2cd3ca4f3..30e492ea7d91 100644 --- a/tests/python/unittest/test_pass_datatype_rewrite.py +++ b/tests/python/unittest/test_pass_datatype_rewrite.py @@ -104,8 +104,21 @@ def test_thread_axis_3dim(): assert stmt.body.body.node.var.dtype == "int32" +def test_vectorize(): + def test(m, lanes, dtype): + A = te.placeholder((m,), name='A', dtype='float32x{}'.format(lanes)) + B = te.placeholder((m,), name='B', dtype='float32x{}'.format(lanes)) + C = te.compute((m,), lambda *idx: A[idx] + B[idx]) + s = te.create_schedule(C.op) + stmt = lower(s, [A, B, C]) + assert stmt.body.loop_var.dtype == dtype + test(tvm.tir.const(64, dtype='int32'), 2, 'int32') + test(tvm.tir.const(2 ** 32, dtype='int64'), 2, 'int64') + + if __name__ == "__main__": test_const() test_symbolic() test_thread_axis_2dim() test_thread_axis_3dim() + test_vectorize() From d2c82374a8ee519e59b283856c5891e562b1d2b0 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Sat, 14 Mar 2020 19:32:19 +0800 Subject: [PATCH 07/25] Add comments --- include/tvm/tir/ir_pass.h | 5 + src/tir/pass/rewrite_datatype.cc | 19 +-- .../unittest/test_pass_datatype_rewrite.py | 116 ++++++++---------- 3 files changed, 69 insertions(+), 71 deletions(-) diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h index 51ed40c1011f..06ffa7534cf0 100644 --- a/include/tvm/tir/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -386,6 +386,11 @@ Stmt DecorateDeviceScope(Stmt stmt); */ Stmt HoistIfThenElse(Stmt stmt); +/*! + * \brief Narrow down PrimExpr datatype in stmt + * \param stmt The stmt to do datatype rewrite + * \return Transformed stmt. + */ Stmt DataTypeRewrite(Stmt stmt); /*! diff --git a/src/tir/pass/rewrite_datatype.cc b/src/tir/pass/rewrite_datatype.cc index 0e1c116eabb3..04a5e54e3f48 100644 --- a/src/tir/pass/rewrite_datatype.cc +++ b/src/tir/pass/rewrite_datatype.cc @@ -119,15 +119,17 @@ class DataTypeVisitor final : public StmtExprVisitor { } StmtExprVisitor::VisitExpr_(op); } - + // the narrowed datatype of Var, IntImm, and Cast std::unordered_map vmap; protected: - /*! \brief internal analyzer field. */ + // internal analyzer arith::Analyzer analyzer_; private: + // the maximum bits of all containing expressions int bits_; + // the vars to be rewritten std::unordered_set vset_; friend class DataTypeRewriter; }; @@ -219,7 +221,10 @@ class DataTypeRewriter : public StmtExprMutator { PrimExpr VisitExpr_(const CallNode* op) final; private: + // the internal visitor to deduce the narrowed dtype DataTypeVisitor visitor_; + // a map from Var before rewrite to Var after rewrite, + // ensures one old Var maps to exactly one new Var std::unordered_map vmap_; PrimExpr Cast(PrimExpr e, DataType dtype) { if (e.dtype() != dtype) { @@ -230,15 +235,15 @@ class DataTypeRewriter : public StmtExprMutator { } }; -#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ - PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) { \ +#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ + PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) { \ PrimExpr a = this->VisitExpr(op->a); \ PrimExpr b = this->VisitExpr(op->b); \ if (a.same_as(op->a) && \ b.same_as(op->b)) { \ return GetRef(op); \ } else { \ - return FUNC(a, b); \ + return FUNC(a, b); \ } \ } @@ -253,9 +258,9 @@ DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator <) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator >) DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=) PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) { diff --git a/tests/python/unittest/test_pass_datatype_rewrite.py b/tests/python/unittest/test_pass_datatype_rewrite.py index 30e492ea7d91..4ddc64171ff8 100644 --- a/tests/python/unittest/test_pass_datatype_rewrite.py +++ b/tests/python/unittest/test_pass_datatype_rewrite.py @@ -37,74 +37,62 @@ def lower(sch, args): def test_const(): - m, n = 2, 2 - A = te.placeholder((m, n), name='A') - B = te.placeholder((m, n), name='B') - C = te.compute((m, n), lambda *idx: A[idx] + B[idx]) - s = te.create_schedule(C.op) - stmt = lower(s, [A, B, C]) - assert stmt.body.loop_var.dtype == "int32" - assert stmt.body.body.loop_var.dtype == "int32" - m, n = 2**16, 2**16 - A = te.placeholder((m, n), name='A') - B = te.placeholder((m, n), name='B') - C = te.compute((m, n), lambda *idx: A[idx] + B[idx]) - s = te.create_schedule(C.op) - stmt = lower(s, [A, B, C]) + def test(m, n, dtype): + A = te.placeholder((m, n), name='A') + B = te.placeholder((m, n), name='B') + C = te.compute((m, n), lambda *idx: A[idx] + B[idx]) + s = te.create_schedule(C.op) + stmt = lower(s, [A, B, C]) + assert stmt.body.loop_var.dtype == dtype + assert stmt.body.body.loop_var.dtype == dtype + + test(2, 2, "int32") # i32 + i32 is not promoted to i64 even in the case of overflow - assert stmt.body.loop_var.dtype == "int32" - assert stmt.body.body.loop_var.dtype == "int32" + test(2**16, 2**16, "int32") + test(tvm.tir.const(2, dtype='int64'), tvm.tir.const(2, dtype='int64'), "int32") + test(tvm.tir.const(2**16, dtype='int64'), tvm.tir.const(2**16, dtype='int64'), "int64") def test_symbolic(): - m, n = te.size_var(name='m', dtype='int32'), te.size_var(name='n', dtype='int32') - A = te.placeholder((m, n), name='A') - B = te.placeholder((m, n), name='B') - C = te.compute((m, n), lambda *idx: A[idx] + B[idx]) - s = te.create_schedule(C.op) - stmt = lower(s, [A, B, C]) - assert stmt.body.loop_var.dtype == "int32" - assert stmt.body.body.loop_var.dtype == "int32" - m, n = te.size_var(name='m', dtype='int64'), te.size_var(name='n', dtype='int64') - A = te.placeholder((m, n), name='A') - B = te.placeholder((m, n), name='B') - C = te.compute((m, n), lambda *idx: A[idx] + B[idx]) - s = te.create_schedule(C.op) - stmt = lower(s, [A, B, C]) - assert stmt.body.loop_var.dtype == "int64" - assert stmt.body.body.loop_var.dtype == "int64" - - -def test_thread_axis_2dim(): - m, n = 1024, 32 - A = te.placeholder((m, n), name='A') - B = te.placeholder((m, n), name='B') - C = te.compute((m, n), lambda *idx: A[idx] + B[idx]) - s = te.create_schedule(C.op) - s[C].bind(C.op.axis[0], te.thread_axis("blockIdx.x")) - s[C].bind(C.op.axis[1], te.thread_axis("threadIdx.x")) - stmt = lower(s, [A, B, C]) - assert stmt.body.node.var.dtype == "int32" - assert stmt.body.body.node.var.dtype == "int32" - - -def test_thread_axis_3dim(): - m, n, k = 2**12, 2**12, 2**13 - A = te.placeholder((m, n, k), name='A') - B = te.placeholder((m, n, k), name='B') - C = te.compute((m, n, k), lambda *idx: A[idx] + B[idx]) - s = te.create_schedule(C.op) - fused = s[C].fuse(*[axis for axis in C.op.axis]) - xo, xi = s[C].split(fused, factor=32) - s[C].bind(xo, te.thread_axis("blockIdx.x")) - s[C].bind(xi, te.thread_axis("threadIdx.x")) - stmt = lower(s, [A, B, C]) + def test(m, n, dtype): + A = te.placeholder((m, n), name='A') + B = te.placeholder((m, n), name='B') + C = te.compute((m, n), lambda *idx: A[idx] + B[idx]) + s = te.create_schedule(C.op) + stmt = lower(s, [A, B, C]) + assert stmt.body.loop_var.dtype == dtype + assert stmt.body.body.loop_var.dtype == dtype + + test(te.size_var(name='m', dtype='int32'), te.size_var(name='n', dtype='int32'), "int32") + test(te.size_var(name='m', dtype='int64'), te.size_var(name='n', dtype='int64'), "int64") + + +def test_thread_axis(): + def test(m, n, k, dtype): + A = te.placeholder((m, n, k), name='A') + B = te.placeholder((m, n, k), name='B') + C = te.compute((m, n, k), lambda *idx: A[idx] + B[idx]) + s = te.create_schedule(C.op) + fused = s[C].fuse(*[axis for axis in C.op.axis]) + xo, xi = s[C].split(fused, factor=32) + s[C].bind(xo, te.thread_axis("blockIdx.x")) + s[C].bind(xi, te.thread_axis("threadIdx.x")) + stmt = lower(s, [A, B, C]) + + test(2, 2, 2, dtype='int32') # i32 + i32 is not promoted to i64 even in the case of overflow - assert stmt.body.node.var.dtype == "int32" - assert stmt.body.body.node.var.dtype == "int32" + test(2**10, 2**11, 2**12, dtype='int32') + test(tvm.tir.const(2, dtype='int64'), + tvm.tir.const(2, dtype='int64'), + tvm.tir.const(2, dtype='int64'), + dtype='int32') + test(tvm.tir.const(2**10, dtype='int64'), + tvm.tir.const(2**11, dtype='int64'), + tvm.tir.const(2**12, dtype='int64'), + dtype='int64') -def test_vectorize(): +def test_multilanes(): def test(m, lanes, dtype): A = te.placeholder((m,), name='A', dtype='float32x{}'.format(lanes)) B = te.placeholder((m,), name='B', dtype='float32x{}'.format(lanes)) @@ -112,6 +100,7 @@ def test(m, lanes, dtype): s = te.create_schedule(C.op) stmt = lower(s, [A, B, C]) assert stmt.body.loop_var.dtype == dtype + test(tvm.tir.const(64, dtype='int32'), 2, 'int32') test(tvm.tir.const(2 ** 32, dtype='int64'), 2, 'int64') @@ -119,6 +108,5 @@ def test(m, lanes, dtype): if __name__ == "__main__": test_const() test_symbolic() - test_thread_axis_2dim() - test_thread_axis_3dim() - test_vectorize() + test_thread_axis() + test_multilanes() From ca7a74d239ca20b4fc048dc77409d7e0e701f6dd Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 16 Mar 2020 17:11:04 +0800 Subject: [PATCH 08/25] Restricted to StoreNode and LoadNode --- src/tir/pass/rewrite_datatype.cc | 46 +++++++++++++++++++------------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/src/tir/pass/rewrite_datatype.cc b/src/tir/pass/rewrite_datatype.cc index 04a5e54e3f48..76c87c15c7b6 100644 --- a/src/tir/pass/rewrite_datatype.cc +++ b/src/tir/pass/rewrite_datatype.cc @@ -108,18 +108,7 @@ class DataTypeVisitor final : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(op); } - void VisitExpr_(const CastNode* op) { - if (op->dtype.is_int()) { - int bits = std::min(op->dtype.bits(), bits_); - if (vmap.find(op) == vmap.end()) { - vmap[op] = op->dtype.with_bits(bits); - } else { - vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits)); - } - } - StmtExprVisitor::VisitExpr_(op); - } - // the narrowed datatype of Var, IntImm, and Cast + // the narrowed datatype of Var and IntImm std::unordered_map vmap; protected: @@ -141,6 +130,18 @@ class DataTypeRewriter : public StmtExprMutator { return VisitStmt(s); } + Stmt VisitStmt_(const StoreNode* op) final { + PrimExpr value = this->VisitExpr(op->value); + is_index_ = true; + PrimExpr index = this->VisitExpr(op->index); + is_index_ = false; + Stmt s = StoreNode::make(op->buffer_var, + op->value, + index, + op->predicate); + return StmtExprMutator::VisitStmt_(s.as()); + } + Stmt VisitStmt_(const ForNode* op) final { Stmt s = StmtExprMutator::VisitStmt_(op); op = s.as(); @@ -187,18 +188,26 @@ class DataTypeRewriter : public StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } + PrimExpr VisitExpr_(const LoadNode* op) final { + is_index_ = true; + PrimExpr index = this->VisitExpr(op->index); + is_index_ = false; + PrimExpr e = LoadNode::make(op->dtype, op->buffer_var, index, op->predicate); + return StmtExprMutator::VisitExpr_(e.as()); + } + PrimExpr VisitExpr_(const IntImmNode* op) final { - if (visitor_.vmap.find(op) != visitor_.vmap.end()) { - return IntImm(visitor_.vmap[op], op->value); + if (is_index_) { + if (visitor_.vmap.find(op) != visitor_.vmap.end()) { + return IntImm(visitor_.vmap[op], op->value); + } } return StmtExprMutator::VisitExpr_(op); } PrimExpr VisitExpr_(const CastNode* op) final { - if (visitor_.vmap.find(op) != visitor_.vmap.end()) { - PrimExpr e = StmtExprMutator::VisitExpr_(op); - const CastNode* new_op = e.as(); - return CastNode::make(visitor_.vmap[op], new_op->value); + if (is_index_) { + return StmtExprMutator::VisitExpr(op->value); } return StmtExprMutator::VisitExpr_(op); } @@ -226,6 +235,7 @@ class DataTypeRewriter : public StmtExprMutator { // a map from Var before rewrite to Var after rewrite, // ensures one old Var maps to exactly one new Var std::unordered_map vmap_; + bool is_index_{false}; PrimExpr Cast(PrimExpr e, DataType dtype) { if (e.dtype() != dtype) { return CastNode::make(dtype, e); From ad30a835e50eeb697d42ce4c132e60fa5f563eb6 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Tue, 17 Mar 2020 22:24:32 +0800 Subject: [PATCH 09/25] Only narrow, no promotion --- src/target/llvm/llvm_module.cc | 1 - src/tir/pass/rewrite_datatype.cc | 11 ++++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 3f508a530c74..50cd3e7b1ad8 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -237,7 +237,6 @@ class LLVMModuleNode final : public runtime::ModuleNode { if (tm_->getTargetTriple().isOSDarwin()) { module_->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2); } - std::string verify_errors_storage; llvm::raw_string_ostream verify_errors(verify_errors_storage); LOG_IF(FATAL, llvm::verifyModule(*module_, &verify_errors)) diff --git a/src/tir/pass/rewrite_datatype.cc b/src/tir/pass/rewrite_datatype.cc index 76c87c15c7b6..6de019ac7cd3 100644 --- a/src/tir/pass/rewrite_datatype.cc +++ b/src/tir/pass/rewrite_datatype.cc @@ -58,6 +58,7 @@ class DataTypeVisitor final : public StmtExprVisitor { analyzer_.Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); vset_.insert(op->loop_var.as()); + vextent_[op->loop_var.as()] = op->extent.dtype(); return StmtExprVisitor::VisitStmt_(op); } @@ -69,6 +70,7 @@ class DataTypeVisitor final : public StmtExprVisitor { analyzer_.Bind(iv->var, Range::make_by_min_extent(0, op->value)); vset_.insert(iv->var.as()); + vextent_[iv->var.as()] = op->value.dtype(); StmtExprVisitor::VisitStmt_(op); } else { StmtExprVisitor::VisitStmt_(op); @@ -80,6 +82,7 @@ class DataTypeVisitor final : public StmtExprVisitor { for (const IterVar& iv : op->axis) { analyzer_.Bind(iv->var, iv->dom); vset_.insert(iv->var.as()); + vextent_[iv->var.as()] = iv->dom->extent.dtype(); } // Recursively call simplification when necessary. StmtExprVisitor::VisitExpr_(op); @@ -87,10 +90,11 @@ class DataTypeVisitor final : public StmtExprVisitor { void VisitExpr_(const VarNode* op) { if (vset_.find(op) != vset_.end()) { + int bits = std::min(vextent_[op].bits(), bits_); if (vmap.find(op) == vmap.end()) { - vmap[op] = op->dtype.with_bits(bits_); + vmap[op] = op->dtype.with_bits(bits); } else { - vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits_)); + vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits)); } } StmtExprVisitor::VisitExpr_(op); @@ -120,7 +124,8 @@ class DataTypeVisitor final : public StmtExprVisitor { int bits_; // the vars to be rewritten std::unordered_set vset_; - friend class DataTypeRewriter; + // the extent of vars to be rewritten + std::unordered_map vextent_; }; class DataTypeRewriter : public StmtExprMutator { From 73a526ed3cf70557a8eaf066f97b56deab517d93 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Tue, 17 Mar 2020 23:58:51 +0800 Subject: [PATCH 10/25] Fix CodeGenCPU --- src/target/llvm/codegen_cpu.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index ba3dee73dbe5..863ac950bc7f 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -943,7 +943,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { PrimExpr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent); CreateSerialFor(MakeValue(begin), MakeValue(end), - ConstInt32(1), + llvm::ConstantInt::getSigned(LLVMType(end.dtype()), 1), op->loop_var, op->body); } From 03e9093c05fd7f905a60524d8996bd23bbddedca Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Wed, 18 Mar 2020 13:00:50 +0800 Subject: [PATCH 11/25] Fix cast --- src/tir/pass/rewrite_datatype.cc | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/tir/pass/rewrite_datatype.cc b/src/tir/pass/rewrite_datatype.cc index 6de019ac7cd3..5e7375a371bc 100644 --- a/src/tir/pass/rewrite_datatype.cc +++ b/src/tir/pass/rewrite_datatype.cc @@ -112,6 +112,18 @@ class DataTypeVisitor final : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(op); } + void VisitExpr_(const CastNode* op) { + if (op->dtype.is_int()) { + int bits = std::min(op->dtype.bits(), bits_); + if (vmap.find(op) == vmap.end()) { + vmap[op] = op->dtype.with_bits(bits); + } else { + vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits)); + } + } + StmtExprVisitor::VisitExpr_(op); + } + // the narrowed datatype of Var and IntImm std::unordered_map vmap; @@ -211,8 +223,10 @@ class DataTypeRewriter : public StmtExprMutator { } PrimExpr VisitExpr_(const CastNode* op) final { - if (is_index_) { - return StmtExprMutator::VisitExpr(op->value); + if (is_index_ && visitor_.vmap.find(op) != visitor_.vmap.end()) { + PrimExpr e = StmtExprMutator::VisitExpr_(op); + const CastNode* new_op = e.as(); + return CastNode::make(visitor_.vmap[op], new_op->value); } return StmtExprMutator::VisitExpr_(op); } From 4cdeea341b70c7a71acf4a4958736345ad88fbd5 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Wed, 18 Mar 2020 17:47:38 +0800 Subject: [PATCH 12/25] Add reduction tests --- .../unittest/test_pass_datatype_rewrite.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/python/unittest/test_pass_datatype_rewrite.py b/tests/python/unittest/test_pass_datatype_rewrite.py index 4ddc64171ff8..69eee8cc19d6 100644 --- a/tests/python/unittest/test_pass_datatype_rewrite.py +++ b/tests/python/unittest/test_pass_datatype_rewrite.py @@ -105,8 +105,24 @@ def test(m, lanes, dtype): test(tvm.tir.const(2 ** 32, dtype='int64'), 2, 'int64') +def test_reduce(): + def test(m, dtype): + A = te.placeholder((m,), name='A', dtype='float32') + k = te.reduce_axis((0, m), "k") + B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name='B') + s = te.create_schedule(B.op) + stmt = lower(s, [A, B]) + assert stmt.body[1].loop_var.dtype == dtype + + test(tvm.tir.const(64, dtype='int32'), 'int32') + test(tvm.tir.const(64, dtype='int64'), 'int32') + test(te.var('n', dtype='int32'), 'int32') + test(te.var('n', dtype='int64'), 'int64') + + if __name__ == "__main__": test_const() test_symbolic() test_thread_axis() test_multilanes() + test_reduce() From 686a6fb5c6d2f17f875310499fe40395999dbece Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Wed, 18 Mar 2020 21:48:30 +0800 Subject: [PATCH 13/25] Rename --- src/tir/pass/rewrite_datatype.cc | 11 ++--------- ...e_rewrite.py => test_tir_pass_rewrite_datatype.py} | 0 2 files changed, 2 insertions(+), 9 deletions(-) rename tests/python/unittest/{test_pass_datatype_rewrite.py => test_tir_pass_rewrite_datatype.py} (100%) diff --git a/src/tir/pass/rewrite_datatype.cc b/src/tir/pass/rewrite_datatype.cc index 5e7375a371bc..9077408d704c 100644 --- a/src/tir/pass/rewrite_datatype.cc +++ b/src/tir/pass/rewrite_datatype.cc @@ -164,7 +164,7 @@ class DataTypeRewriter : public StmtExprMutator { op = s.as(); PrimExpr e = VisitExpr(op->loop_var); Var var = Downcast(e); - return ForNode::make(var, Cast(op->min, var.dtype()), Cast(op->extent, var.dtype()), + return ForNode::make(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->for_type, op->device_api, op->body); } @@ -179,7 +179,7 @@ class DataTypeRewriter : public StmtExprMutator { return AttrStmtNode::make( IterVarNode::make(iv->dom, var, iv->iter_type, iv->thread_tag), op->attr_key, - Cast(op->value, var.dtype()), + cast(var.dtype(), op->value), op->body); } return StmtExprMutator::VisitStmt_(op); @@ -255,13 +255,6 @@ class DataTypeRewriter : public StmtExprMutator { // ensures one old Var maps to exactly one new Var std::unordered_map vmap_; bool is_index_{false}; - PrimExpr Cast(PrimExpr e, DataType dtype) { - if (e.dtype() != dtype) { - return CastNode::make(dtype, e); - } else { - return e; - } - } }; #define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ diff --git a/tests/python/unittest/test_pass_datatype_rewrite.py b/tests/python/unittest/test_tir_pass_rewrite_datatype.py similarity index 100% rename from tests/python/unittest/test_pass_datatype_rewrite.py rename to tests/python/unittest/test_tir_pass_rewrite_datatype.py From 725f42d1d73eb0e32ff0a5d95bdca2e2c83ff831 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Wed, 18 Mar 2020 22:33:49 +0800 Subject: [PATCH 14/25] Clear --- src/target/llvm/llvm_module.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 50cd3e7b1ad8..3f508a530c74 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -237,6 +237,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { if (tm_->getTargetTriple().isOSDarwin()) { module_->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2); } + std::string verify_errors_storage; llvm::raw_string_ostream verify_errors(verify_errors_storage); LOG_IF(FATAL, llvm::verifyModule(*module_, &verify_errors)) From b3e3da5e871faa755b0ecfaa69ae59051fca2a85 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 19 Mar 2020 03:03:20 +0800 Subject: [PATCH 15/25] Fix multiple instances of one IterVar --- src/tir/pass/rewrite_datatype.cc | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/tir/pass/rewrite_datatype.cc b/src/tir/pass/rewrite_datatype.cc index 9077408d704c..dbede80d281f 100644 --- a/src/tir/pass/rewrite_datatype.cc +++ b/src/tir/pass/rewrite_datatype.cc @@ -173,11 +173,14 @@ class DataTypeRewriter : public StmtExprMutator { op->attr_key == attr::virtual_thread) { Stmt s = StmtExprMutator::VisitStmt_(op); op = s.as(); - IterVar iv = Downcast(op->node); + const IterVarNode* iv = op->node.as(); PrimExpr e = VisitExpr(iv->var); Var var = Downcast(e); + if (ivmap_.find(iv) == ivmap_.end()) { + ivmap_[iv] = IterVarNode::make(iv->dom, var, iv->iter_type, iv->thread_tag); + } return AttrStmtNode::make( - IterVarNode::make(iv->dom, var, iv->iter_type, iv->thread_tag), + ivmap_[iv], op->attr_key, cast(var.dtype(), op->value), op->body); @@ -251,9 +254,12 @@ class DataTypeRewriter : public StmtExprMutator { private: // the internal visitor to deduce the narrowed dtype DataTypeVisitor visitor_; - // a map from Var before rewrite to Var after rewrite, + // a map from Var before rewrite to that after rewrite, // ensures one old Var maps to exactly one new Var std::unordered_map vmap_; + // a map from IterVar before rewrite to that after rewrite, + // ensures one old IterVar maps to exactly one new IterVar + std::unordered_map ivmap_; bool is_index_{false}; }; From dde45e207de29f7c84cfb859dd0a1f4e69240444 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 19 Mar 2020 18:52:13 +0800 Subject: [PATCH 16/25] Remove unnecessary cast --- src/tir/pass/rewrite_datatype.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/tir/pass/rewrite_datatype.cc b/src/tir/pass/rewrite_datatype.cc index dbede80d281f..f2418919ad73 100644 --- a/src/tir/pass/rewrite_datatype.cc +++ b/src/tir/pass/rewrite_datatype.cc @@ -144,6 +144,14 @@ class DataTypeRewriter : public StmtExprMutator { public: Stmt operator()(Stmt s) { visitor_(s); + for (auto i = visitor_.vmap.begin(), last = visitor_.vmap.end(); i != last;) { + PrimExpr e = GetRef(i->first); + if (e.dtype() == i->second) { + i = visitor_.vmap.erase(i); + } else { + ++i; + } + } return VisitStmt(s); } From ec815d333c175e3030c6c98f36780bd1d8f6823b Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 19 Mar 2020 18:54:10 +0800 Subject: [PATCH 17/25] Resolve comments --- src/tir/pass/rewrite_datatype.cc | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/src/tir/pass/rewrite_datatype.cc b/src/tir/pass/rewrite_datatype.cc index f2418919ad73..b1246edbfcd3 100644 --- a/src/tir/pass/rewrite_datatype.cc +++ b/src/tir/pass/rewrite_datatype.cc @@ -19,6 +19,7 @@ /*! * \file rewrite_datatype.cc + * \brief narrow the datatype of indexing vars */ #include @@ -57,8 +58,7 @@ class DataTypeVisitor final : public StmtExprVisitor { void VisitStmt_(const ForNode* op) { analyzer_.Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); - vset_.insert(op->loop_var.as()); - vextent_[op->loop_var.as()] = op->extent.dtype(); + vextent_[op->loop_var.as()] = op->extent.dtype(); return StmtExprVisitor::VisitStmt_(op); } @@ -69,8 +69,7 @@ class DataTypeVisitor final : public StmtExprVisitor { CHECK_NE(iv->thread_tag.length(), 0U); analyzer_.Bind(iv->var, Range::make_by_min_extent(0, op->value)); - vset_.insert(iv->var.as()); - vextent_[iv->var.as()] = op->value.dtype(); + vextent_[iv->var.as()] = op->value.dtype(); StmtExprVisitor::VisitStmt_(op); } else { StmtExprVisitor::VisitStmt_(op); @@ -81,15 +80,14 @@ class DataTypeVisitor final : public StmtExprVisitor { // Setup the domain information before simplification. for (const IterVar& iv : op->axis) { analyzer_.Bind(iv->var, iv->dom); - vset_.insert(iv->var.as()); - vextent_[iv->var.as()] = iv->dom->extent.dtype(); + vextent_[iv->var.as()] = iv->dom->extent.dtype(); } // Recursively call simplification when necessary. StmtExprVisitor::VisitExpr_(op); } void VisitExpr_(const VarNode* op) { - if (vset_.find(op) != vset_.end()) { + if (vextent_.find(op) != vextent_.end()) { int bits = std::min(vextent_[op].bits(), bits_); if (vmap.find(op) == vmap.end()) { vmap[op] = op->dtype.with_bits(bits); @@ -125,19 +123,17 @@ class DataTypeVisitor final : public StmtExprVisitor { } // the narrowed datatype of Var and IntImm - std::unordered_map vmap; + std::unordered_map vmap; protected: // internal analyzer arith::Analyzer analyzer_; private: - // the maximum bits of all containing expressions + // the maximum possible bit of the current expression's return dtype int bits_; - // the vars to be rewritten - std::unordered_set vset_; // the extent of vars to be rewritten - std::unordered_map vextent_; + std::unordered_map vextent_; }; class DataTypeRewriter : public StmtExprMutator { From 130373d764b852c09717af847d7235e2253827d6 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Fri, 20 Mar 2020 23:30:13 +0800 Subject: [PATCH 18/25] Resolve comments --- include/tvm/tir/ir_pass.h | 2 +- python/tvm/driver/build_module.py | 2 +- src/target/llvm/codegen_cpu.cc | 2 +- src/target/llvm/codegen_llvm.cc | 2 +- src/tir/pass/ffi_api.cc | 2 +- ...rewrite_datatype.cc => narrow_datatype.cc} | 41 ++++++++++++++++++- ...pe.py => test_tir_pass_narrow_datatype.py} | 2 +- 7 files changed, 45 insertions(+), 8 deletions(-) rename src/tir/pass/{rewrite_datatype.cc => narrow_datatype.cc} (88%) rename tests/python/unittest/{test_tir_pass_rewrite_datatype.py => test_tir_pass_narrow_datatype.py} (98%) diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h index 06ffa7534cf0..c341fc34b168 100644 --- a/include/tvm/tir/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -391,7 +391,7 @@ Stmt HoistIfThenElse(Stmt stmt); * \param stmt The stmt to do datatype rewrite * \return Transformed stmt. */ -Stmt DataTypeRewrite(Stmt stmt); +Stmt NarrowDataType(Stmt stmt); /*! * \brief Make an user callable API LoweredFunc. diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 220de8eb1282..7ef5565fbd4a 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -159,7 +159,7 @@ def lower(sch, # Phase 1 stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds) stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers) - stmt = ir_pass.DataTypeRewrite(stmt) + stmt = ir_pass.NarrowDataType(stmt) stmt = ir_pass.CanonicalSimplify(stmt) for f in lower_phase1: stmt = f(stmt) diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 863ac950bc7f..70bcfe88c30e 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -943,7 +943,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { PrimExpr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent); CreateSerialFor(MakeValue(begin), MakeValue(end), - llvm::ConstantInt::getSigned(LLVMType(end.dtype()), 1), + llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body); } diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 4dfcdf97ff13..31465cd56bcb 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1121,7 +1121,7 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) { CHECK(op->for_type == ForType::Serial); } CreateSerialFor(MakeValue(op->min), MakeValue(op->extent), - llvm::ConstantInt::getSigned(LLVMType(op->extent.dtype()), 1), + llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body); } diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc index 2c711ae39b47..8172ac037250 100644 --- a/src/tir/pass/ffi_api.cc +++ b/src/tir/pass/ffi_api.cc @@ -165,6 +165,6 @@ REGISTER_PASS(InstrumentBoundCheckers); REGISTER_PASS(VerifyCompactBuffer); REGISTER_PASS(HoistIfThenElse); REGISTER_PASS(InferFragment) -REGISTER_PASS(DataTypeRewrite); +REGISTER_PASS(NarrowDataType); } // namespace tir } // namespace tvm diff --git a/src/tir/pass/rewrite_datatype.cc b/src/tir/pass/narrow_datatype.cc similarity index 88% rename from src/tir/pass/rewrite_datatype.cc rename to src/tir/pass/narrow_datatype.cc index b1246edbfcd3..bd7a8d40c128 100644 --- a/src/tir/pass/rewrite_datatype.cc +++ b/src/tir/pass/narrow_datatype.cc @@ -18,7 +18,7 @@ */ /*! - * \file rewrite_datatype.cc + * \file narrow_datatype.cc * \brief narrow the datatype of indexing vars */ @@ -30,6 +30,28 @@ namespace tvm { namespace tir { +// This pass narrows indexing expressions (like StoreNode::Index) +// that trivially fit into i32 to i32. Considering that i32 indices +// may be more efficient on some backends (while i64 may be more +// efficient on others, like llvm), we may want this pass when i32 +// indices are more efficient. +// +// For Var v, we determine its dtype by examining all the PrimExpr +// that contains v, denoted by E = {e_0 = v, e_1, e_2, ..., e_k}. +// If all expressions in E fit into i32, then we think v can be narrowed +// to i32. +// +// To make an indexing expression i32, we must make sure that every +// component of that expression is of dtype i32. So besides Var, we +// rewrite the following inside an indexing expression +// - Var +// - IntImm +// - Cast +// +// Algorithm: +// - Use DataTypeVisitor to determine whether a Var can be narrowed or not. +// - Use DataTypeRewritter to rewrite the components of an indexing expression. + using arith::Analyzer; using arith::IRMutatorWithAnalyzer; using arith::ConstIntBound; @@ -166,6 +188,9 @@ class DataTypeRewriter : public StmtExprMutator { Stmt VisitStmt_(const ForNode* op) final { Stmt s = StmtExprMutator::VisitStmt_(op); op = s.as(); + CHECK(op != nullptr) + << "Expected type to be ForNode" + << ", but get " << s->GetTypeKey(); PrimExpr e = VisitExpr(op->loop_var); Var var = Downcast(e); return ForNode::make(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), @@ -177,7 +202,13 @@ class DataTypeRewriter : public StmtExprMutator { op->attr_key == attr::virtual_thread) { Stmt s = StmtExprMutator::VisitStmt_(op); op = s.as(); + CHECK(op != nullptr) + << "Expected type to be AttrStmtNode" + << ", but get " << s->GetTypeKey(); const IterVarNode* iv = op->node.as(); + CHECK(iv != nullptr) + << "Expected type to be IterVarNode" + << ", but get " << op->node->GetTypeKey(); PrimExpr e = VisitExpr(iv->var); Var var = Downcast(e); if (ivmap_.find(iv) == ivmap_.end()) { @@ -233,6 +264,9 @@ class DataTypeRewriter : public StmtExprMutator { if (is_index_ && visitor_.vmap.find(op) != visitor_.vmap.end()) { PrimExpr e = StmtExprMutator::VisitExpr_(op); const CastNode* new_op = e.as(); + CHECK(new_op != nullptr) + << "Expected type to be CastNode" + << ", but get " << e->GetTypeKey(); return CastNode::make(visitor_.vmap[op], new_op->value); } return StmtExprMutator::VisitExpr_(op); @@ -298,6 +332,9 @@ DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=) PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) { PrimExpr e = StmtExprMutator::VisitExpr_(op); op = e.as(); + CHECK(op != nullptr) + << "Expected type to be CallNode" + << ", but get " << e->GetTypeKey(); if (op->call_type == CallNode::PureIntrinsic) { if (op->name == intrinsic::tvm_if_then_else) { return if_then_else(op->args[0], op->args[1], op->args[2]); @@ -318,7 +355,7 @@ PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) { return e; } -Stmt DataTypeRewrite(Stmt stmt) { +Stmt NarrowDataType(Stmt stmt) { return DataTypeRewriter()(stmt); } diff --git a/tests/python/unittest/test_tir_pass_rewrite_datatype.py b/tests/python/unittest/test_tir_pass_narrow_datatype.py similarity index 98% rename from tests/python/unittest/test_tir_pass_rewrite_datatype.py rename to tests/python/unittest/test_tir_pass_narrow_datatype.py index 69eee8cc19d6..297dc166aa6a 100644 --- a/tests/python/unittest/test_tir_pass_rewrite_datatype.py +++ b/tests/python/unittest/test_tir_pass_narrow_datatype.py @@ -32,7 +32,7 @@ def lower(sch, args): bounds = te.schedule.InferBound(sch) stmt = te.schedule.ScheduleOps(sch, bounds) stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False) - stmt = tvm.tir.ir_pass.DataTypeRewrite(stmt) + stmt = tvm.tir.ir_pass.NarrowDataType(stmt) return stmt From 04c4a10db252530c1c82facf6e97c056dc1e7625 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Sat, 21 Mar 2020 23:04:11 +0800 Subject: [PATCH 19/25] Resolve comments --- include/tvm/tir/ir_pass.h | 3 +- python/tvm/driver/build_module.py | 2 +- src/tir/pass/narrow_datatype.cc | 34 ++++++++++++------- .../unittest/test_tir_pass_narrow_datatype.py | 2 +- 4 files changed, 26 insertions(+), 15 deletions(-) diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h index c341fc34b168..39b39d28202e 100644 --- a/include/tvm/tir/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -389,9 +389,10 @@ Stmt HoistIfThenElse(Stmt stmt); /*! * \brief Narrow down PrimExpr datatype in stmt * \param stmt The stmt to do datatype rewrite + * \param target_bits the bit of target datatype * \return Transformed stmt. */ -Stmt NarrowDataType(Stmt stmt); +Stmt NarrowDataType(Stmt stmt, int target_bits); /*! * \brief Make an user callable API LoweredFunc. diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 7ef5565fbd4a..88231aaba56c 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -159,7 +159,7 @@ def lower(sch, # Phase 1 stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds) stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers) - stmt = ir_pass.NarrowDataType(stmt) + stmt = ir_pass.NarrowDataType(stmt, 32) stmt = ir_pass.CanonicalSimplify(stmt) for f in lower_phase1: stmt = f(stmt) diff --git a/src/tir/pass/narrow_datatype.cc b/src/tir/pass/narrow_datatype.cc index bd7a8d40c128..863edea0199d 100644 --- a/src/tir/pass/narrow_datatype.cc +++ b/src/tir/pass/narrow_datatype.cc @@ -56,22 +56,25 @@ using arith::Analyzer; using arith::IRMutatorWithAnalyzer; using arith::ConstIntBound; -class DataTypeRewriter; - class DataTypeVisitor final : public StmtExprVisitor { public: + explicit DataTypeVisitor(int target_bits) + : bits_(target_bits), target_bits_(target_bits) {} + void VisitExpr(const PrimExpr& e) { if (e.dtype().is_int()) { - int bits = 64; - if (e.dtype().bits() <= 32 || - analyzer_.CanProve(e <= max_value(DataType::Int(32)) && - e >= min_value(DataType::Int(32)))) { - bits = 32; + int bits = max_bits_; + ConstIntBound bound = analyzer_.const_int_bound(e); + int64_t ubound = Downcast(max_value(DataType::Int(target_bits_)))->value; + int64_t lbound = Downcast(min_value(DataType::Int(target_bits_)))->value; + if (e.dtype().bits() <= target_bits_ || + (bound->max_value <= ubound && bound->min_value >= lbound)) { + bits = target_bits_; } - int tmp = bits_; - bits_ = bits > bits_ ? bits : bits_; + int tmp = bits > bits_ ? bits : bits_; + std::swap(bits_, tmp); StmtExprVisitor::VisitExpr(e); - bits_ = tmp; + std::swap(bits_, tmp); } else { StmtExprVisitor::VisitExpr(e); } @@ -152,14 +155,20 @@ class DataTypeVisitor final : public StmtExprVisitor { arith::Analyzer analyzer_; private: + // the maximum possible bits, which serves as an init value + static constexpr const int max_bits_ = 64; // the maximum possible bit of the current expression's return dtype int bits_; + // the target bits + int target_bits_; // the extent of vars to be rewritten std::unordered_map vextent_; }; class DataTypeRewriter : public StmtExprMutator { public: + explicit DataTypeRewriter(int target_bits): visitor_(target_bits) {} + Stmt operator()(Stmt s) { visitor_(s); for (auto i = visitor_.vmap.begin(), last = visitor_.vmap.end(); i != last;) { @@ -298,6 +307,7 @@ class DataTypeRewriter : public StmtExprMutator { // a map from IterVar before rewrite to that after rewrite, // ensures one old IterVar maps to exactly one new IterVar std::unordered_map ivmap_; + // indicator of LoadNode::index and StoreNode::index bool is_index_{false}; }; @@ -355,8 +365,8 @@ PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) { return e; } -Stmt NarrowDataType(Stmt stmt) { - return DataTypeRewriter()(stmt); +Stmt NarrowDataType(Stmt stmt, int target_bits) { + return DataTypeRewriter(target_bits)(stmt); } } // namespace tir diff --git a/tests/python/unittest/test_tir_pass_narrow_datatype.py b/tests/python/unittest/test_tir_pass_narrow_datatype.py index 297dc166aa6a..539decc54053 100644 --- a/tests/python/unittest/test_tir_pass_narrow_datatype.py +++ b/tests/python/unittest/test_tir_pass_narrow_datatype.py @@ -32,7 +32,7 @@ def lower(sch, args): bounds = te.schedule.InferBound(sch) stmt = te.schedule.ScheduleOps(sch, bounds) stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False) - stmt = tvm.tir.ir_pass.NarrowDataType(stmt) + stmt = tvm.tir.ir_pass.NarrowDataType(stmt, 32) return stmt From 0578b85db141b5df92309ceadd5df185a7cdec67 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 30 Mar 2020 00:44:55 +0800 Subject: [PATCH 20/25] Use ir_builder to test --- python/tvm/tir/ir_builder.py | 6 +- src/arith/const_int_bound.cc | 10 ++ .../unittest/test_tir_pass_narrow_datatype.py | 143 ++++++++++-------- 3 files changed, 90 insertions(+), 69 deletions(-) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 885b8475082e..0c4c36888eb5 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -76,7 +76,8 @@ def dtype(self): def __getitem__(self, index): t = DataType(self._content_type) if t.lanes > 1: - index = _expr.Ramp(index * t.lanes, 1, t.lanes) + base = index * t.lanes + index = _expr.Ramp(base, const(1, base.dtype), t.lanes) return _expr.Load(self._content_type, self._buffer_var, index) def __setitem__(self, index, value): @@ -87,7 +88,8 @@ def __setitem__(self, index, value): value.dtype, self._content_type)) t = DataType(self._content_type) if t.lanes > 1: - index = _expr.Ramp(index * t.lanes, 1, t.lanes) + base = index * t.lanes + index = _expr.Ramp(base, const(1, base.dtype), t.lanes) self._builder.emit(_stmt.Store(self._buffer_var, value, index)) diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 9ef5723e153e..443b48e048f0 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -148,6 +148,16 @@ class ConstIntBoundAnalyzer::Impl : return res; } + Entry VisitExpr_(const RampNode* op) final { + // op = {base + i * stride | 0 <= i < lanes} + // Entry(op) = Union(Entry(base + i * stride) | 0 <= i < lanes) + // Note that `base + i * stride` is linear w.r.t. `i` + // Entry(op) = Union(Entry(base + i * stride) | i = 0, i = lanes-1) + Entry a = VisitExpr(op->base); + Entry b = VisitExpr(op->base + (op->lanes - 1) * op->stride); + return Union(a, b); + } + Entry VisitExpr_(const CastNode* op) final { Entry a = VisitExpr(op->value); Entry b = Everything(op->dtype); diff --git a/tests/python/unittest/test_tir_pass_narrow_datatype.py b/tests/python/unittest/test_tir_pass_narrow_datatype.py index 539decc54053..fc91f6356164 100644 --- a/tests/python/unittest/test_tir_pass_narrow_datatype.py +++ b/tests/python/unittest/test_tir_pass_narrow_datatype.py @@ -36,77 +36,87 @@ def lower(sch, args): return stmt -def test_const(): - def test(m, n, dtype): - A = te.placeholder((m, n), name='A') - B = te.placeholder((m, n), name='B') - C = te.compute((m, n), lambda *idx: A[idx] + B[idx]) - s = te.create_schedule(C.op) - stmt = lower(s, [A, B, C]) - assert stmt.body.loop_var.dtype == dtype - assert stmt.body.body.loop_var.dtype == dtype - - test(2, 2, "int32") - # i32 + i32 is not promoted to i64 even in the case of overflow - test(2**16, 2**16, "int32") - test(tvm.tir.const(2, dtype='int64'), tvm.tir.const(2, dtype='int64'), "int32") - test(tvm.tir.const(2**16, dtype='int64'), tvm.tir.const(2**16, dtype='int64'), "int64") - - -def test_symbolic(): - def test(m, n, dtype): - A = te.placeholder((m, n), name='A') - B = te.placeholder((m, n), name='B') - C = te.compute((m, n), lambda *idx: A[idx] + B[idx]) - s = te.create_schedule(C.op) - stmt = lower(s, [A, B, C]) - assert stmt.body.loop_var.dtype == dtype - assert stmt.body.body.loop_var.dtype == dtype - - test(te.size_var(name='m', dtype='int32'), te.size_var(name='n', dtype='int32'), "int32") - test(te.size_var(name='m', dtype='int64'), te.size_var(name='n', dtype='int64'), "int64") +def test_basic(): + def check(m, n, target_bits, target_dtype): + ib = tvm.tir.ir_builder.create() + Ab = tvm.tir.decl_buffer((m, n), name='A') + A = ib.buffer_ptr(Ab) + Bb = tvm.tir.decl_buffer((m, n), name='B') + B = ib.buffer_ptr(Bb) + with ib.for_range(0, m, name='i') as i: + with ib.for_range(0, n, name='j') as j: + B[i * n + j] = A[i * n + j] + 1 + stmt = ib.get() + stmt = tvm.tir.ir_pass.NarrowDataType(stmt, target_bits) + assert stmt.loop_var.dtype == target_dtype + assert stmt.body.loop_var.dtype == target_dtype + + # const shape + check(2, 2, 32, "int32") + check(2**16, 2**16, 32, "int32") # i32 + i32 is not promoted to i64 even if overflow + check(tvm.tir.const(2, dtype='int64'), tvm.tir.const(2, dtype='int64'), 32, "int32") + check(tvm.tir.const(2**16, dtype='int64'), tvm.tir.const(2**16, dtype='int64'), 32, "int64") + # symbolic shape + check(te.size_var(name='m', dtype='int32'), te.size_var(name='n', dtype='int32'), 32, "int32") + check(te.size_var(name='m', dtype='int64'), te.size_var(name='n', dtype='int64'), 32, "int64") def test_thread_axis(): - def test(m, n, k, dtype): - A = te.placeholder((m, n, k), name='A') - B = te.placeholder((m, n, k), name='B') - C = te.compute((m, n, k), lambda *idx: A[idx] + B[idx]) - s = te.create_schedule(C.op) - fused = s[C].fuse(*[axis for axis in C.op.axis]) - xo, xi = s[C].split(fused, factor=32) - s[C].bind(xo, te.thread_axis("blockIdx.x")) - s[C].bind(xi, te.thread_axis("threadIdx.x")) - stmt = lower(s, [A, B, C]) - - test(2, 2, 2, dtype='int32') + def check(m, n, target_bits, target_dtype): + ib = tvm.tir.ir_builder.create() + Ab = tvm.tir.decl_buffer((m, n), name='A') + A = ib.buffer_ptr(Ab) + Bb = tvm.tir.decl_buffer((m, n), name='B') + B = ib.buffer_ptr(Bb) + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(bx, "thread_extent", m) + ib.scope_attr(tx, "thread_extent", n) + B[bx * n + tx] = A[bx * n + tx] + 1 + stmt = ib.get() + stmt = tvm.tir.ir_pass.NarrowDataType(stmt, target_bits) + assert stmt.node.var.dtype == target_dtype + assert stmt.body.node.var.dtype == target_dtype + + + check(2, 32, + target_bits=32, target_dtype='int32') # i32 + i32 is not promoted to i64 even in the case of overflow - test(2**10, 2**11, 2**12, dtype='int32') - test(tvm.tir.const(2, dtype='int64'), - tvm.tir.const(2, dtype='int64'), - tvm.tir.const(2, dtype='int64'), - dtype='int32') - test(tvm.tir.const(2**10, dtype='int64'), - tvm.tir.const(2**11, dtype='int64'), - tvm.tir.const(2**12, dtype='int64'), - dtype='int64') + check(2**30, 32, + target_bits=32, target_dtype='int32') + check(tvm.tir.const(2, dtype='int64'), + tvm.tir.const(32, dtype='int64'), + target_bits=32, target_dtype='int32') + check(tvm.tir.const(2**30, dtype='int64'), + tvm.tir.const(32, dtype='int64'), + target_bits=32, target_dtype='int64') def test_multilanes(): - def test(m, lanes, dtype): - A = te.placeholder((m,), name='A', dtype='float32x{}'.format(lanes)) - B = te.placeholder((m,), name='B', dtype='float32x{}'.format(lanes)) - C = te.compute((m,), lambda *idx: A[idx] + B[idx]) - s = te.create_schedule(C.op) - stmt = lower(s, [A, B, C]) - assert stmt.body.loop_var.dtype == dtype - - test(tvm.tir.const(64, dtype='int32'), 2, 'int32') - test(tvm.tir.const(2 ** 32, dtype='int64'), 2, 'int64') + def check(m, lanes, target_bits, target_dtype): + ib = tvm.tir.ir_builder.create() + Ab = tvm.tir.decl_buffer((m,), dtype='float32x{}'.format(lanes), name='A') + A = ib.buffer_ptr(Ab) + Bb = tvm.tir.decl_buffer((m,), dtype='float32x{}'.format(lanes), name='B') + B = ib.buffer_ptr(Bb) + with ib.for_range(0, m, name='i', dtype=m.dtype) as i: + B[i] = A[i] + 1 + stmt = ib.get() + stmt = tvm.tir.ir_pass.NarrowDataType(stmt, target_bits) + assert stmt.loop_var.dtype == target_dtype + + check(tvm.tir.const(2 ** 10, dtype='int32'), 2, + target_bits=32, target_dtype='int32') + check(tvm.tir.const(2 ** 32, dtype='int32'), 2, + target_bits=32, target_dtype='int32') + check(tvm.tir.const(2 ** 10, dtype='int64'), 2, + target_bits=32, target_dtype='int32') + check(tvm.tir.const(2 ** 32, dtype='int64'), 2, + target_bits=32, target_dtype='int64') def test_reduce(): - def test(m, dtype): + def check(m, dtype): A = te.placeholder((m,), name='A', dtype='float32') k = te.reduce_axis((0, m), "k") B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name='B') @@ -114,15 +124,14 @@ def test(m, dtype): stmt = lower(s, [A, B]) assert stmt.body[1].loop_var.dtype == dtype - test(tvm.tir.const(64, dtype='int32'), 'int32') - test(tvm.tir.const(64, dtype='int64'), 'int32') - test(te.var('n', dtype='int32'), 'int32') - test(te.var('n', dtype='int64'), 'int64') + check(tvm.tir.const(64, dtype='int32'), 'int32') + check(tvm.tir.const(64, dtype='int64'), 'int32') + check(te.var('n', dtype='int32'), 'int32') + check(te.var('n', dtype='int64'), 'int64') if __name__ == "__main__": - test_const() - test_symbolic() + test_basic() test_thread_axis() test_multilanes() test_reduce() From 9c5acee1d86596d1244a749f7c9f2cee108fc3f7 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 30 Mar 2020 18:37:42 +0800 Subject: [PATCH 21/25] Add comments --- src/tir/pass/narrow_datatype.cc | 40 +++++--- .../unittest/test_tir_pass_narrow_datatype.py | 93 ++++++++++++++----- 2 files changed, 100 insertions(+), 33 deletions(-) diff --git a/src/tir/pass/narrow_datatype.cc b/src/tir/pass/narrow_datatype.cc index 863edea0199d..d376e41d1e64 100644 --- a/src/tir/pass/narrow_datatype.cc +++ b/src/tir/pass/narrow_datatype.cc @@ -31,18 +31,19 @@ namespace tvm { namespace tir { // This pass narrows indexing expressions (like StoreNode::Index) -// that trivially fit into i32 to i32. Considering that i32 indices -// may be more efficient on some backends (while i64 may be more -// efficient on others, like llvm), we may want this pass when i32 +// that trivially fit into i32/i16 (denoted by `target_bits_`) to +// i32/i16. Considering that i32/i16 indices may be more +// efficient on some backends (while i64 may be more efficient +// on others, like llvm), we may want this pass when i32/i16 // indices are more efficient. // // For Var v, we determine its dtype by examining all the PrimExpr // that contains v, denoted by E = {e_0 = v, e_1, e_2, ..., e_k}. -// If all expressions in E fit into i32, then we think v can be narrowed -// to i32. +// If all expressions in E fit into i32/i16, then we think v can be narrowed +// to i32/i16. // -// To make an indexing expression i32, we must make sure that every -// component of that expression is of dtype i32. So besides Var, we +// To make an indexing expression i32/i16, we must make sure that every +// component of that expression is of dtype i32/i16. So besides Var, we // rewrite the following inside an indexing expression // - Var // - IntImm @@ -56,6 +57,16 @@ using arith::Analyzer; using arith::IRMutatorWithAnalyzer; using arith::ConstIntBound; +// Determine the result dtype for Var, IntImm and Cast, +// which will be stored in `vmap` eventually. +// +// Algorithm: +// We propogate the dtypes of all the Exprs that contain Var `var` into `vmap[var]`. +// To be more specific, if for each Expr `e` which contains `var` +// (`var` is a child node of `e` in AST), `e` fits into `target_bits_`, +// then we narrow `var` into `target_bits_`. That is, +// `vmap[var] = min(target_bits_, var.dtype.bits())` +// Otherwise, `var` is not narrowed, that is, `vmap[var] = var.dtype.bits()` class DataTypeVisitor final : public StmtExprVisitor { public: explicit DataTypeVisitor(int target_bits) @@ -65,8 +76,8 @@ class DataTypeVisitor final : public StmtExprVisitor { if (e.dtype().is_int()) { int bits = max_bits_; ConstIntBound bound = analyzer_.const_int_bound(e); - int64_t ubound = Downcast(max_value(DataType::Int(target_bits_)))->value; - int64_t lbound = Downcast(min_value(DataType::Int(target_bits_)))->value; + int64_t ubound = Downcast(max_value(DataType::Int(target_bits_)))->value; + int64_t lbound = Downcast(min_value(DataType::Int(target_bits_)))->value; if (e.dtype().bits() <= target_bits_ || (bound->max_value <= ubound && bound->min_value >= lbound)) { bits = target_bits_; @@ -113,10 +124,13 @@ class DataTypeVisitor final : public StmtExprVisitor { void VisitExpr_(const VarNode* op) { if (vextent_.find(op) != vextent_.end()) { + // We only narrow and never promote, so the result dtype + // is upperbounded by its original dtype before rewrite. int bits = std::min(vextent_[op].bits(), bits_); if (vmap.find(op) == vmap.end()) { vmap[op] = op->dtype.with_bits(bits); } else { + // We take maximum bits for all the possible Expr where a var occurs vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits)); } } @@ -125,6 +139,8 @@ class DataTypeVisitor final : public StmtExprVisitor { void VisitExpr_(const IntImmNode* op) { if (op->dtype.is_int()) { + // We only narrow and never promote, so the result dtype + // is upperbounded by its original dtype before rewrite. int bits = std::min(op->dtype.bits(), bits_); if (vmap.find(op) == vmap.end()) { vmap[op] = op->dtype.with_bits(bits); @@ -137,6 +153,8 @@ class DataTypeVisitor final : public StmtExprVisitor { void VisitExpr_(const CastNode* op) { if (op->dtype.is_int()) { + // We only narrow and never promote, so the result dtype + // is upperbounded by its original dtype before rewrite. int bits = std::min(op->dtype.bits(), bits_); if (vmap.find(op) == vmap.end()) { vmap[op] = op->dtype.with_bits(bits); @@ -201,7 +219,7 @@ class DataTypeRewriter : public StmtExprMutator { << "Expected type to be ForNode" << ", but get " << s->GetTypeKey(); PrimExpr e = VisitExpr(op->loop_var); - Var var = Downcast(e); + Var var = Downcast(e); return ForNode::make(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->for_type, op->device_api, op->body); } @@ -219,7 +237,7 @@ class DataTypeRewriter : public StmtExprMutator { << "Expected type to be IterVarNode" << ", but get " << op->node->GetTypeKey(); PrimExpr e = VisitExpr(iv->var); - Var var = Downcast(e); + Var var = Downcast(e); if (ivmap_.find(iv) == ivmap_.end()) { ivmap_[iv] = IterVarNode::make(iv->dom, var, iv->iter_type, iv->thread_tag); } diff --git a/tests/python/unittest/test_tir_pass_narrow_datatype.py b/tests/python/unittest/test_tir_pass_narrow_datatype.py index fc91f6356164..85b3f63261dd 100644 --- a/tests/python/unittest/test_tir_pass_narrow_datatype.py +++ b/tests/python/unittest/test_tir_pass_narrow_datatype.py @@ -16,9 +16,10 @@ # under the License. import tvm from tvm import te +from tvm.tir import const -def lower(sch, args): +def lower(sch, args, target_bits): binds = {} arg_list = [] for x in args: @@ -32,7 +33,7 @@ def lower(sch, args): bounds = te.schedule.InferBound(sch) stmt = te.schedule.ScheduleOps(sch, bounds) stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False) - stmt = tvm.tir.ir_pass.NarrowDataType(stmt, 32) + stmt = tvm.tir.ir_pass.NarrowDataType(stmt, target_bits) return stmt @@ -52,10 +53,16 @@ def check(m, n, target_bits, target_dtype): assert stmt.body.loop_var.dtype == target_dtype # const shape + # i32 -> i32 check(2, 2, 32, "int32") check(2**16, 2**16, 32, "int32") # i32 + i32 is not promoted to i64 even if overflow - check(tvm.tir.const(2, dtype='int64'), tvm.tir.const(2, dtype='int64'), 32, "int32") - check(tvm.tir.const(2**16, dtype='int64'), tvm.tir.const(2**16, dtype='int64'), 32, "int64") + # i64 -> i32 + check(const(2, dtype='int64'), const(2, dtype='int64'), 32, "int32") + check(const(2**16, dtype='int64'), const(2**16, dtype='int64'), 32, "int64") + # i32 -> i16 + check(2, 2, 16, "int16") + check(2**10, 2**10, 16, "int32") + # symbolic shape check(te.size_var(name='m', dtype='int32'), te.size_var(name='n', dtype='int32'), 32, "int32") check(te.size_var(name='m', dtype='int64'), te.size_var(name='n', dtype='int64'), 32, "int64") @@ -78,18 +85,23 @@ def check(m, n, target_bits, target_dtype): assert stmt.node.var.dtype == target_dtype assert stmt.body.node.var.dtype == target_dtype - + # i32 -> i32 check(2, 32, target_bits=32, target_dtype='int32') - # i32 + i32 is not promoted to i64 even in the case of overflow - check(2**30, 32, + check(2**30, 32, # i32 + i32 is not promoted to i64 even in the case of overflow target_bits=32, target_dtype='int32') - check(tvm.tir.const(2, dtype='int64'), - tvm.tir.const(32, dtype='int64'), + # i64 -> i32 + check(const(2, dtype='int64'), + const(32, dtype='int64'), target_bits=32, target_dtype='int32') - check(tvm.tir.const(2**30, dtype='int64'), - tvm.tir.const(32, dtype='int64'), + check(const(2**30, dtype='int64'), + const(32, dtype='int64'), target_bits=32, target_dtype='int64') + # i32 -> i16 + check(2, 32, + target_bits=16, target_dtype='int16') + check(2**14, 32, + target_bits=16, target_dtype='int32') def test_multilanes(): @@ -105,29 +117,65 @@ def check(m, lanes, target_bits, target_dtype): stmt = tvm.tir.ir_pass.NarrowDataType(stmt, target_bits) assert stmt.loop_var.dtype == target_dtype - check(tvm.tir.const(2 ** 10, dtype='int32'), 2, + # i32 -> i32 + check(const(2 ** 10, dtype='int32'), 2, target_bits=32, target_dtype='int32') - check(tvm.tir.const(2 ** 32, dtype='int32'), 2, + check(const(2 ** 32, dtype='int32'), 2, target_bits=32, target_dtype='int32') - check(tvm.tir.const(2 ** 10, dtype='int64'), 2, + # i64 -> i32 + check(const(2 ** 10, dtype='int64'), 2, target_bits=32, target_dtype='int32') - check(tvm.tir.const(2 ** 32, dtype='int64'), 2, + check(const(2 ** 32, dtype='int64'), 2, target_bits=32, target_dtype='int64') + # i32 -> i16 + check(const(2 ** 10, dtype='int32'), 2, + target_bits=16, target_dtype='int16') + check(const(2 ** 16, dtype='int32'), 2, + target_bits=16, target_dtype='int32') def test_reduce(): - def check(m, dtype): + def check(m, target_bits, target_dtype): A = te.placeholder((m,), name='A', dtype='float32') k = te.reduce_axis((0, m), "k") B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name='B') s = te.create_schedule(B.op) - stmt = lower(s, [A, B]) - assert stmt.body[1].loop_var.dtype == dtype + stmt = lower(s, [A, B], target_bits) + assert stmt.body[1].loop_var.dtype == target_dtype + + # i32 -> i32 + check(const(64, dtype='int32'), 32, 'int32') + # i64 -> i32 + check(const(64, dtype='int64'), 32, 'int32') + # i32 -> i16 + check(const(64, dtype='int32'), 16, 'int16') + check(const(2**16, dtype='int32'), 16, 'int32') + # symbolic + check(te.var('n', dtype='int32'), 32, 'int32') + check(te.var('n', dtype='int64'), 32, 'int64') + + +def test_slice(): + def check(m, n, target_bits, target_dtype): + ib = tvm.tir.ir_builder.create() + Ab = tvm.tir.decl_buffer((m, n), name='A') + A = ib.buffer_ptr(Ab) + Bb = tvm.tir.decl_buffer((m, n * 2), name='B') + B = ib.buffer_ptr(Bb) + with ib.for_range(0, m, name='i') as i: + with ib.for_range(0, n, name='j') as j: + A[i * n + j] = B[i * 2 * n + 2 * j] + 1 + stmt = ib.get() + stmt = tvm.tir.ir_pass.NarrowDataType(stmt, target_bits) + assert stmt.loop_var.dtype == target_dtype + assert stmt.body.loop_var.dtype == target_dtype - check(tvm.tir.const(64, dtype='int32'), 'int32') - check(tvm.tir.const(64, dtype='int64'), 'int32') - check(te.var('n', dtype='int32'), 'int32') - check(te.var('n', dtype='int64'), 'int64') + # The maximum index is (2**15 * 2**15 - 1) * 2 <= 2**31 - 1 + check(const(2**15, 'int64'), const(2**15, 'int64'), + target_bits=32, target_dtype='int32') + # The maximum index is (2**15 * 2**15 - 1 + 2**15) * 2 > 2**31 - 1 + check(const(2**15, 'int64'), const((2**15 + 1), 'int64'), + target_bits=32, target_dtype='int64') if __name__ == "__main__": @@ -135,3 +183,4 @@ def check(m, dtype): test_thread_axis() test_multilanes() test_reduce() + test_slice() From 6cee7b760808c1dd2edc8f412fde488f157fc3c0 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 30 Mar 2020 19:29:44 +0800 Subject: [PATCH 22/25] Fix sanity --- python/tvm/tir/expr.py | 3 ++- src/tir/pass/narrow_datatype.cc | 4 ++-- tests/python/unittest/test_tir_pass_narrow_datatype.py | 1 + 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index deb8d3446fc1..a192fce6439a 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -370,7 +370,8 @@ def __init__(self, dom, var, iter_type, thread_tag=""): raise TypeError("dom need to be Range") name = var if var is not None else "iter" - var = Var(name, dtype="int32") if not isinstance(var, Var) else var + dtype = "int32" if dom is None else dom.extent.dtype + var = Var(name, dtype=dtype) if not isinstance(var, Var) else var self.__init_handle_by_constructor__( _ffi_api.IterVar, dom, var, iter_type, thread_tag) diff --git a/src/tir/pass/narrow_datatype.cc b/src/tir/pass/narrow_datatype.cc index d376e41d1e64..e6bc43abf7cb 100644 --- a/src/tir/pass/narrow_datatype.cc +++ b/src/tir/pass/narrow_datatype.cc @@ -31,7 +31,7 @@ namespace tvm { namespace tir { // This pass narrows indexing expressions (like StoreNode::Index) -// that trivially fit into i32/i16 (denoted by `target_bits_`) to +// that trivially fit into i32/i16 (denoted by `target_bits_`) to // i32/i16. Considering that i32/i16 indices may be more // efficient on some backends (while i64 may be more efficient // on others, like llvm), we may want this pass when i32/i16 @@ -62,7 +62,7 @@ using arith::ConstIntBound; // // Algorithm: // We propogate the dtypes of all the Exprs that contain Var `var` into `vmap[var]`. -// To be more specific, if for each Expr `e` which contains `var` +// To be more specific, if for each Expr `e` which contains `var` // (`var` is a child node of `e` in AST), `e` fits into `target_bits_`, // then we narrow `var` into `target_bits_`. That is, // `vmap[var] = min(target_bits_, var.dtype.bits())` diff --git a/tests/python/unittest/test_tir_pass_narrow_datatype.py b/tests/python/unittest/test_tir_pass_narrow_datatype.py index 85b3f63261dd..96a8ef41f793 100644 --- a/tests/python/unittest/test_tir_pass_narrow_datatype.py +++ b/tests/python/unittest/test_tir_pass_narrow_datatype.py @@ -157,6 +157,7 @@ def check(m, target_bits, target_dtype): def test_slice(): def check(m, n, target_bits, target_dtype): + # The index may overflow in B, while not in A ib = tvm.tir.ir_builder.create() Ab = tvm.tir.decl_buffer((m, n), name='A') A = ib.buffer_ptr(Ab) From 458f0b3e4ee737422447f17c84b46d5c682a9155 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 2 Apr 2020 14:52:59 +0800 Subject: [PATCH 23/25] Migrate to transform pass --- include/tvm/tir/ir_pass.h | 3 ++- include/tvm/tir/transform.h | 10 ++++++++ python/tvm/tir/transform/transform.py | 15 ++++++++++++ .../{pass => transforms}/narrow_datatype.cc | 22 ++++++++++++++++++ ... => test_tir_transform_narrow_datatype.py} | 23 ++++++++++++------- 5 files changed, 64 insertions(+), 9 deletions(-) rename src/tir/{pass => transforms}/narrow_datatype.cc (95%) rename tests/python/unittest/{test_tir_pass_narrow_datatype.py => test_tir_transform_narrow_datatype.py} (91%) diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h index 39b39d28202e..d63eac5a7065 100644 --- a/include/tvm/tir/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -387,7 +387,8 @@ Stmt DecorateDeviceScope(Stmt stmt); Stmt HoistIfThenElse(Stmt stmt); /*! - * \brief Narrow down PrimExpr datatype in stmt + * \brief Narrow down PrimExpr datatype in stmt to target_bits. + * \note Run this pass after StorageFlatten. * \param stmt The stmt to do datatype rewrite * \param target_bits the bit of target datatype * \return Transformed stmt. diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 9d55db571016..a414bccfafa1 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -87,6 +87,16 @@ TVM_DLL Pass LowerDeviceStorageAccessInfo(); */ TVM_DLL Pass LowerWarpMemory(); + +/*! + * \brief Narrow down PrimExpr datatype in stmt to target_bits. + * + * \note Run this pass after StorageFlatten. + * + * \return The pass. + */ +TVM_DLL Pass NarrowDataType(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 2b50387ff5b8..7c2b3c812714 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -66,3 +66,18 @@ def LowerWarpMemory(): The result pass """ return _ffi_api.LowerWarpMemory() + + +def NarrowDataType(): + """Narrow down PrimExpr datatype in stmt to target_bits. + + Returns + ------- + fpass : tvm.ir.transform.Pass + The result pass + + Note + ---- + Run this pass after StorageFlatten. + """ + return _ffi_api.NarrowDataType() diff --git a/src/tir/pass/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc similarity index 95% rename from src/tir/pass/narrow_datatype.cc rename to src/tir/transforms/narrow_datatype.cc index e6bc43abf7cb..03741fa9408a 100644 --- a/src/tir/pass/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -24,6 +24,8 @@ #include #include +#include +#include #include "../../arith/ir_mutator_with_analyzer.h" #include "../../arith/ir_visitor_with_analyzer.h" @@ -387,5 +389,25 @@ Stmt NarrowDataType(Stmt stmt, int target_bits) { return DataTypeRewriter(target_bits)(stmt); } +namespace transform { + +Pass NarrowDataType() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + // TODO(@hzfan): should Target be Attr here, with target_bits inferred from it? + IntImm target_bits = f->GetAttr("target_bits"); + CHECK(target_bits.defined()) + << "NarrowDataType: Require the target_bits"; + n->body = DataTypeRewriter(target_bits->value)(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass( + pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType") +.set_body_typed(NarrowDataType); + +} // namespace transform } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_pass_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py similarity index 91% rename from tests/python/unittest/test_tir_pass_narrow_datatype.py rename to tests/python/unittest/test_tir_transform_narrow_datatype.py index 96a8ef41f793..49df1c22033f 100644 --- a/tests/python/unittest/test_tir_pass_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -19,7 +19,15 @@ from tvm.tir import const -def lower(sch, args, target_bits): +def lower_stmt(params, stmt, target_bits): + func = tvm.tir.PrimFunc(params, stmt).with_attr( + "target_bits", target_bits) + func = tvm.tir.transform.NarrowDataType()(tvm.IRModule.from_expr(func))["main"] + stmt = func.body + return stmt + + +def lower_sch(sch, args, target_bits): binds = {} arg_list = [] for x in args: @@ -33,8 +41,7 @@ def lower(sch, args, target_bits): bounds = te.schedule.InferBound(sch) stmt = te.schedule.ScheduleOps(sch, bounds) stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False) - stmt = tvm.tir.ir_pass.NarrowDataType(stmt, target_bits) - return stmt + return lower_stmt(arg_list, stmt, target_bits) def test_basic(): @@ -48,7 +55,7 @@ def check(m, n, target_bits, target_dtype): with ib.for_range(0, n, name='j') as j: B[i * n + j] = A[i * n + j] + 1 stmt = ib.get() - stmt = tvm.tir.ir_pass.NarrowDataType(stmt, target_bits) + stmt = lower_stmt([Ab, Bb], stmt, target_bits) assert stmt.loop_var.dtype == target_dtype assert stmt.body.loop_var.dtype == target_dtype @@ -81,7 +88,7 @@ def check(m, n, target_bits, target_dtype): ib.scope_attr(tx, "thread_extent", n) B[bx * n + tx] = A[bx * n + tx] + 1 stmt = ib.get() - stmt = tvm.tir.ir_pass.NarrowDataType(stmt, target_bits) + stmt = lower_stmt([Ab, Bb], stmt, target_bits) assert stmt.node.var.dtype == target_dtype assert stmt.body.node.var.dtype == target_dtype @@ -114,7 +121,7 @@ def check(m, lanes, target_bits, target_dtype): with ib.for_range(0, m, name='i', dtype=m.dtype) as i: B[i] = A[i] + 1 stmt = ib.get() - stmt = tvm.tir.ir_pass.NarrowDataType(stmt, target_bits) + stmt = lower_stmt([Ab, Bb], stmt, target_bits) assert stmt.loop_var.dtype == target_dtype # i32 -> i32 @@ -140,7 +147,7 @@ def check(m, target_bits, target_dtype): k = te.reduce_axis((0, m), "k") B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name='B') s = te.create_schedule(B.op) - stmt = lower(s, [A, B], target_bits) + stmt = lower_sch(s, [A, B], target_bits) assert stmt.body[1].loop_var.dtype == target_dtype # i32 -> i32 @@ -167,7 +174,7 @@ def check(m, n, target_bits, target_dtype): with ib.for_range(0, n, name='j') as j: A[i * n + j] = B[i * 2 * n + 2 * j] + 1 stmt = ib.get() - stmt = tvm.tir.ir_pass.NarrowDataType(stmt, target_bits) + stmt = lower_stmt([Ab, Bb], stmt, target_bits) assert stmt.loop_var.dtype == target_dtype assert stmt.body.loop_var.dtype == target_dtype From e96044937c8426fd0ec5e1cd794b1f71c08ca0f1 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 2 Apr 2020 17:27:05 +0800 Subject: [PATCH 24/25] ConstIntBound with memorization --- include/tvm/arith/analyzer.h | 8 ++++++++ src/arith/const_int_bound.cc | 22 ++++++++++++++++++++++ src/tir/transforms/narrow_datatype.cc | 9 +++++++-- 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index e7f5ede22995..4a85bb320b72 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -114,6 +114,14 @@ class ConstIntBoundAnalyzer { */ ConstIntBound operator()(const PrimExpr& expr); + /*! + * \brief analyze the expr with the intermediate memorized to avoid redundant computation + * \param expr The expression of interest. + * \return the result of the analysis. + */ + ConstIntBound operator()(const PrimExpr& expr, + std::unordered_map* bound); + /*! * \brief Update constant int bound information of var. * diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 443b48e048f0..e26f4c8e6dd2 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -145,6 +145,17 @@ class ConstIntBoundAnalyzer::Impl : res = Intersect(res, info.bound); } } + if (bound_) { + const PrimExprNode* op = expr.as(); + auto val = bound_->find(op); + if (val != bound_->end()) { + CHECK(val->second->min_value == res.min_value && + val->second->max_value == res.max_value) + << "Detected bound for " << expr + << "conflicts with memorization"; + } + (*bound_)[op] = ConstIntBound(res.min_value, res.max_value); + } return res; } @@ -349,10 +360,13 @@ class ConstIntBoundAnalyzer::Impl : } private: + friend class ConstIntBoundAnalyzer; // internal variable map std::unordered_map var_map_; // additional bound info std::vector additional_info_; + // look up table for memorization + std::unordered_map* bound_{nullptr}; // constants: the limit value means umlimited // NOTE: kNegInf/kPosInf are used to represent infinity. static const constexpr int64_t kNegInf = ConstIntBound::kNegInf; @@ -545,6 +559,14 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) { return ConstIntBound(ret.min_value, ret.max_value); } +ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr, + std::unordered_map* bound) { + impl_->bound_ = bound; + Entry ret = impl_->VisitExpr(expr); + impl_->bound_ = nullptr; + return ConstIntBound(ret.min_value, ret.max_value); +} + void ConstIntBoundAnalyzer::Update(const Var& var, const ConstIntBound& info, bool override) { diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 03741fa9408a..00bc45aec052 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -77,7 +77,11 @@ class DataTypeVisitor final : public StmtExprVisitor { void VisitExpr(const PrimExpr& e) { if (e.dtype().is_int()) { int bits = max_bits_; - ConstIntBound bound = analyzer_.const_int_bound(e); + const PrimExprNode* op = e.as(); + if (bound_.find(op) == bound_.end()) { + analyzer_.const_int_bound(e, &bound_); + } + ConstIntBound bound = bound_[op]; int64_t ubound = Downcast(max_value(DataType::Int(target_bits_)))->value; int64_t lbound = Downcast(min_value(DataType::Int(target_bits_)))->value; if (e.dtype().bits() <= target_bits_ || @@ -183,6 +187,8 @@ class DataTypeVisitor final : public StmtExprVisitor { int target_bits_; // the extent of vars to be rewritten std::unordered_map vextent_; + // the memorized bound generated by ConstIntBoundAnalyzer + std::unordered_map bound_; }; class DataTypeRewriter : public StmtExprMutator { @@ -394,7 +400,6 @@ namespace transform { Pass NarrowDataType() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - // TODO(@hzfan): should Target be Attr here, with target_bits inferred from it? IntImm target_bits = f->GetAttr("target_bits"); CHECK(target_bits.defined()) << "NarrowDataType: Require the target_bits"; From eb5e02a2702db613bb7a39a20be549772bef2a0e Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 2 Apr 2020 21:23:24 +0800 Subject: [PATCH 25/25] Fix sanity --- include/tvm/arith/analyzer.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 4a85bb320b72..1889e16fef66 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -117,6 +117,7 @@ class ConstIntBoundAnalyzer { /*! * \brief analyze the expr with the intermediate memorized to avoid redundant computation * \param expr The expression of interest. + * \param bound The lookup table to store the intermediate results * \return the result of the analysis. */ ConstIntBound operator()(const PrimExpr& expr,