Skip to content

Commit

Permalink
[TIR][Bugfix] Improved massive build times caused by tir.floormod and…
Browse files Browse the repository at this point in the history
… tir.floordiv. Fixed Topi testcase. (apache#5666)

* Improved uncommon case of floormod and floordiv. Removed dependence on np floor_div and fmod.

* Fixed clang-format complaints

* Streamlined floormod and floordiv lowering logic

* Improved build times by expressing int64 case of tir FloorMod and FloorDiv using let nodes

* Updated use-def analysis and llvm codegen to support duplicated letnodes.

* Corrected misuse of var_map_ in llvm codegen

* Updated backends that support LetNode

* Changed floormod and div lowering logic to avoid using FP on systems that don't support it.

* Fixed formatting

Co-authored-by: pankratz <[email protected]>
  • Loading branch information
2 people authored and Trevor Morris committed Sep 2, 2020
1 parent bdff511 commit 11438ab
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 39 deletions.
8 changes: 7 additions & 1 deletion src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) {
}

llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) {
CHECK(!var_map_.count(op->var.get()));
auto it = let_binding_.find(op->var);
if (it != let_binding_.end()) {
CHECK(deep_equal_(it->second->value, op->value))
<< "Let cannot bind the same var to two different values";
} else {
let_binding_[op->var] = op;
}
var_map_[op->var.get()] = MakeValue(op->value);
analyzer_->Bind(op->var, op->value);
return MakeValue(op->body);
Expand Down
5 changes: 5 additions & 0 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/ir/module.h>
#include <tvm/runtime/container.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
Expand Down Expand Up @@ -321,6 +322,10 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
std::unordered_set<const VarNode*> alias_var_set_;
// set of volatile buffer.
std::unordered_set<const VarNode*> volatile_buf_;
// deep comparison of PrimExpr
ExprDeepEqual deep_equal_;
// binding of let variables. Enables duplicate var defs that map to same value
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
// Cache potential common path ops to slightly improve lookup time.
// global symbol table.
OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
Expand Down
8 changes: 7 additions & 1 deletion src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -761,8 +761,14 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
}

void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*)
auto it = let_binding_.find(op->var);
if (it != let_binding_.end()) {
CHECK(deep_equal_(it->second->value, op->value))
<< "Let cannot bind the same var to two different values";
} else {
let_binding_[op->var] = op;
}
std::string value = PrintExpr(op->value);
CHECK(!var_idmap_.count(op->var.get()));
var_idmap_[op->var.get()] = value;
os << PrintExpr(op->body);
}
Expand Down
5 changes: 5 additions & 0 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/ir/op.h>
#include <tvm/runtime/container.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
Expand Down Expand Up @@ -269,6 +270,10 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
bool print_ssa_form_{false};
/*! \brief set of volatile buf access */
std::unordered_set<const VarNode*> volatile_buf_;
// deep comparison of PrimExpr
ExprDeepEqual deep_equal_;
// binding of let variables. Enables duplicate var defs that map to same value
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
};

} // namespace codegen
Expand Down
8 changes: 7 additions & 1 deletion src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,13 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const SelectNode* op) {
}

spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) {
CHECK(!var_map_.count(op->var.get()));
auto it = let_binding_.find(op->var);
if (it != let_binding_.end()) {
CHECK(deep_equal_(it->second->value, op->value))
<< "Let cannot bind the same var to two different values";
} else {
let_binding_[op->var] = op;
}
var_map_[op->var.get()] = MakeValue(op->value);
analyzer_->Bind(op->var, op->value);
return MakeValue(op->body);
Expand Down
5 changes: 5 additions & 0 deletions src/target/spirv/codegen_spirv.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_TARGET_SPIRV_CODEGEN_SPIRV_H_

#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
Expand Down Expand Up @@ -140,6 +141,10 @@ class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
std::unordered_map<const VarNode*, spirv::Value> var_map_;
// The analyzer.
std::unique_ptr<arith::Analyzer> analyzer_;
// deep comparison of PrimExpr
ExprDeepEqual deep_equal_;
// binding of let variables. Enables duplicate var defs that map to same value
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
};

} // namespace codegen
Expand Down
47 changes: 31 additions & 16 deletions src/tir/transforms/lower_intrin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,22 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
}
}
} else {
// uncommon case
DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor";
// b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
// b < 0 => (rmod <= 0 ? rdiv : rdiv - 1)
PrimExpr rdiv = truncdiv(op->a, op->b);
PrimExpr rmod = truncmod(op->a, op->b);
return tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv,
rdiv - make_const(dtype, 1));
if (dtype.is_float()) {
// floor(a / b)
return VisitExpr_(tvm::floor(op->a / op->b).as<CallNode>());
} else {
// uncommon case
DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor";
auto rmod = tir::Var("rmod", dtype);
auto rdiv = tir::Var("rdiv", dtype);
// b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
// b < 0 => (rmod <= 0 ? rdiv : rdiv - 1)
PrimExpr let_rdiv =
tir::Let(rdiv, truncdiv(op->a, op->b),
tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv,
rdiv - make_const(dtype, 1)));
return Let(rmod, truncmod(op->a, op->b), let_rdiv);
}
}
}

Expand Down Expand Up @@ -158,14 +166,21 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
}
}
} else {
// uncommon case
DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident";
PrimExpr rmod = truncmod(op->a, op->b);
// b > 0 && rmod >= 0 -> rmod
// b > 0 && rmod < 0 -> rmod + b
// b < 0 && rmod < 0 -> rmod
// b < 0 && rmod > 0 -> rmod + b
return tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, rmod + op->b);
if (dtype.is_float()) {
// a - floor(a / b) * b
return op->a - (VisitExpr_(tvm::floor(op->a / op->b).as<CallNode>()) * op->b);
} else {
// uncommon case
DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident";
auto rmod = tir::Var("rmod", dtype);
// b > 0 && rmod >= 0 -> rmod
// b > 0 && rmod < 0 -> rmod + b
// b < 0 && rmod < 0 -> rmod
// b < 0 && rmod > 0 -> rmod + b
return Let(
rmod, truncmod(op->a, op->b),
Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, rmod + op->b));
}
}
}

Expand Down
22 changes: 20 additions & 2 deletions src/tir/transforms/split_host_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,28 @@ class VarUseDefAnalysis : public StmtExprMutator {
}

PrimExpr VisitExpr_(const LetNode* op) final {
this->HandleDef(op->var.get());
// Weaker SSA condition
// A single var can be binded in multiple lets
// but they have to bind to the same value.
// This is used to allow cases when we reuse a single let
// expression to construct a nested expr.
// (let x = 1 in x + 1) * (let x = 1 in x + 1)
auto it = let_binding_.find(op->var);
PrimExpr value = this->VisitExpr(op->value);
if (it != let_binding_.end()) {
CHECK(deep_equal_(it->second->value, value))
<< "Let cannot bind the same var to two different values";
return GetRef<PrimExpr>(it->second);
} else {
this->HandleDef(op->var.get());
let_binding_[op->var] = op;
}
PrimExpr body = this->VisitExpr(op->body);
// eliminate unreferenced let
if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState &&
simplify_let_) {
return body;
} else {
PrimExpr value = this->VisitExpr(op->value);
if (body.same_as(op->body) && value.same_as(op->value)) {
return GetRef<PrimExpr>(op);
} else {
Expand Down Expand Up @@ -157,6 +171,10 @@ class VarUseDefAnalysis : public StmtExprMutator {
Array<PrimExpr> thread_extent_;
std::unordered_map<const VarNode*, int> use_count_;
std::unordered_map<const VarNode*, int> def_count_;

private:
ExprDeepEqual deep_equal_;
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
};

Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
Expand Down
27 changes: 9 additions & 18 deletions topi/tests/python/test_topi_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,6 @@ def check_device(device):
rhs_npy, rhs_nd = gen_operand(rhs_shape, rhs_min, rhs_max, ctx)
out_npy = fnumpy(lhs_npy, rhs_npy)

if fnumpy == np.floor_divide:
# avoid check too close to X.5 and X.0
# FIXME: floor_divide(94.90735, 0.6731018) behaves as floor(div(94.90735, 0.6731018))
# However the result is somehow incorrect - need to further investigate.
# And looks like numpy's floor_div(a,b) is implemented different from floor(div(a,b))
mask = np.logical_or(np.abs(np.abs(np.fmod(lhs_npy / rhs_npy, 1)) - 0.5) < 1e-6,
np.abs(np.fmod(lhs_npy / rhs_npy, 1)) < 1e-6)
if mask.any():
lhs_npy = lhs_npy + mask * 1e-3 * rhs_npy
lhs_npy = lhs_npy.astype(dtype)
lhs_nd = tvm.nd.array(lhs_npy, ctx) if lhs_shape is not None else lhs_npy.item()
out_npy = fnumpy(lhs_npy, rhs_npy)

out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
foo(lhs_nd, rhs_nd, out_nd)
tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
Expand Down Expand Up @@ -151,12 +138,14 @@ def test_divide():
(2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001)

def test_floor_divide():
def _canonical_floor_div(a,b):
return np.floor(a / b)
verify_broadcast_binary_ele(
None, (10,), topi.floor_divide, np.floor_divide, rhs_min=0.0001)
None, (10,), topi.floor_divide, _canonical_floor_div, rhs_min=0.0001)
verify_broadcast_binary_ele(
(), None, topi.floor_divide, np.floor_divide, rhs_min=0.0001)
(), None, topi.floor_divide, _canonical_floor_div, rhs_min=0.0001)
verify_broadcast_binary_ele(
(2, 3, 64, 32), (64, 32), topi.floor_divide, np.floor_divide, rhs_min=0.0001)
(2, 3, 64, 32), (64, 32), topi.floor_divide, _canonical_floor_div, rhs_min=0.0001)

def test_maximum_minmum():
verify_broadcast_binary_ele(
Expand All @@ -175,10 +164,12 @@ def test_mod():
(1, 2, 2), (2,), topi.mod, np.mod, lhs_min=0.001, rhs_min=1, dtype="int32")

def test_floor_mod():
def _canonical_floor_mod(a,b):
return a - np.floor(a / b) * b
verify_broadcast_binary_ele(
(1, 2, 2), (2,), topi.floor_mod, np.fmod, lhs_min=0.001, rhs_min=1, dtype="int32")
(1, 2, 2), (2,), topi.floor_mod, _canonical_floor_mod, lhs_min=0.001, rhs_min=1, dtype="int32")
verify_broadcast_binary_ele(
(3, 4, 5), (3, 4, 5), topi.floor_mod, np.fmod, lhs_min=0.001, rhs_min=1, dtype="float32")
(3, 4, 5), (3, 4, 5), topi.floor_mod, _canonical_floor_mod, lhs_min=0.001, rhs_min=1, dtype="float32")

def test_cmp():
# explicit specify the output type
Expand Down

0 comments on commit 11438ab

Please sign in to comment.