diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index cb28b8187222..225d22599554 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1018,7 +1018,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) { - CHECK(!var_map_.count(op->var.get())); + auto it = let_binding_.find(op->var); + if (it != let_binding_.end()) { + CHECK(deep_equal_(it->second->value, op->value)) + << "Let cannot bind the same var to two different values"; + } else { + let_binding_[op->var] = op; + } var_map_[op->var.get()] = MakeValue(op->value); analyzer_->Bind(op->var, op->value); return MakeValue(op->body); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index e20c8e197969..ce5baba7ea19 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -321,6 +322,10 @@ class CodeGenLLVM : public ExprFunctor, std::unordered_set alias_var_set_; // set of volatile buffer. std::unordered_set volatile_buf_; + // deep comparison of PrimExpr + ExprDeepEqual deep_equal_; + // binding of let variables. Enables duplicate var defs that map to same value + std::unordered_map let_binding_; // Cache potential common path ops to slightly improve lookup time. // global symbol table. OpAttrMap op_attr_global_symbol_ = Op::GetAttrMap("TGlobalSymbol"); diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 153089204576..3e6838ce1b0b 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -761,8 +761,14 @@ void CodeGenC::VisitStmt_(const StoreNode* op) { } void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*) + auto it = let_binding_.find(op->var); + if (it != let_binding_.end()) { + CHECK(deep_equal_(it->second->value, op->value)) + << "Let cannot bind the same var to two different values"; + } else { + let_binding_[op->var] = op; + } std::string value = PrintExpr(op->value); - CHECK(!var_idmap_.count(op->var.get())); var_idmap_[op->var.get()] = value; os << PrintExpr(op->body); } diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 87a4a2944130..c1b566c064a4 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -269,6 +270,10 @@ class CodeGenC : public ExprFunctor, bool print_ssa_form_{false}; /*! \brief set of volatile buf access */ std::unordered_set volatile_buf_; + // deep comparison of PrimExpr + ExprDeepEqual deep_equal_; + // binding of let variables. Enables duplicate var defs that map to same value + std::unordered_map let_binding_; }; } // namespace codegen diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 7ff0c5520e24..2a67d953f960 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -230,7 +230,13 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const SelectNode* op) { } spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) { - CHECK(!var_map_.count(op->var.get())); + auto it = let_binding_.find(op->var); + if (it != let_binding_.end()) { + CHECK(deep_equal_(it->second->value, op->value)) + << "Let cannot bind the same var to two different values"; + } else { + let_binding_[op->var] = op; + } var_map_[op->var.get()] = MakeValue(op->value); analyzer_->Bind(op->var, op->value); return MakeValue(op->body); diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index a8af29a194d5..9bf81095f066 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -25,6 +25,7 @@ #define TVM_TARGET_SPIRV_CODEGEN_SPIRV_H_ #include +#include #include #include #include @@ -140,6 +141,10 @@ class CodeGenSPIRV : public ExprFunctor, std::unordered_map var_map_; // The analyzer. std::unique_ptr analyzer_; + // deep comparison of PrimExpr + ExprDeepEqual deep_equal_; + // binding of let variables. Enables duplicate var defs that map to same value + std::unordered_map let_binding_; }; } // namespace codegen diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 1c529d86523e..f3fe945b8d4c 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -112,14 +112,22 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } } } else { - // uncommon case - DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor"; - // b >= 0 => (rmod >=0 ? rdiv : rdiv - 1) - // b < 0 => (rmod <= 0 ? rdiv : rdiv - 1) - PrimExpr rdiv = truncdiv(op->a, op->b); - PrimExpr rmod = truncmod(op->a, op->b); - return tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv, - rdiv - make_const(dtype, 1)); + if (dtype.is_float()) { + // floor(a / b) + return VisitExpr_(tvm::floor(op->a / op->b).as()); + } else { + // uncommon case + DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor"; + auto rmod = tir::Var("rmod", dtype); + auto rdiv = tir::Var("rdiv", dtype); + // b >= 0 => (rmod >=0 ? rdiv : rdiv - 1) + // b < 0 => (rmod <= 0 ? rdiv : rdiv - 1) + PrimExpr let_rdiv = + tir::Let(rdiv, truncdiv(op->a, op->b), + tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv, + rdiv - make_const(dtype, 1))); + return Let(rmod, truncmod(op->a, op->b), let_rdiv); + } } } @@ -158,14 +166,21 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } } } else { - // uncommon case - DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident"; - PrimExpr rmod = truncmod(op->a, op->b); - // b > 0 && rmod >= 0 -> rmod - // b > 0 && rmod < 0 -> rmod + b - // b < 0 && rmod < 0 -> rmod - // b < 0 && rmod > 0 -> rmod + b - return tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, rmod + op->b); + if (dtype.is_float()) { + // a - floor(a / b) * b + return op->a - (VisitExpr_(tvm::floor(op->a / op->b).as()) * op->b); + } else { + // uncommon case + DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident"; + auto rmod = tir::Var("rmod", dtype); + // b > 0 && rmod >= 0 -> rmod + // b > 0 && rmod < 0 -> rmod + b + // b < 0 && rmod < 0 -> rmod + // b < 0 && rmod > 0 -> rmod + b + return Let( + rmod, truncmod(op->a, op->b), + Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, rmod + op->b)); + } } } diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 169ac1401445..d5b51cbf2236 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -99,14 +99,28 @@ class VarUseDefAnalysis : public StmtExprMutator { } PrimExpr VisitExpr_(const LetNode* op) final { - this->HandleDef(op->var.get()); + // Weaker SSA condition + // A single var can be binded in multiple lets + // but they have to bind to the same value. + // This is used to allow cases when we reuse a single let + // expression to construct a nested expr. + // (let x = 1 in x + 1) * (let x = 1 in x + 1) + auto it = let_binding_.find(op->var); + PrimExpr value = this->VisitExpr(op->value); + if (it != let_binding_.end()) { + CHECK(deep_equal_(it->second->value, value)) + << "Let cannot bind the same var to two different values"; + return GetRef(it->second); + } else { + this->HandleDef(op->var.get()); + let_binding_[op->var] = op; + } PrimExpr body = this->VisitExpr(op->body); // eliminate unreferenced let if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState && simplify_let_) { return body; } else { - PrimExpr value = this->VisitExpr(op->value); if (body.same_as(op->body) && value.same_as(op->value)) { return GetRef(op); } else { @@ -157,6 +171,10 @@ class VarUseDefAnalysis : public StmtExprMutator { Array thread_extent_; std::unordered_map use_count_; std::unordered_map def_count_; + + private: + ExprDeepEqual deep_equal_; + std::unordered_map let_binding_; }; Array UndefinedVars(const Stmt& stmt, const Array& args) { diff --git a/topi/tests/python/test_topi_broadcast.py b/topi/tests/python/test_topi_broadcast.py index 27b66e04e394..f3e0300a2d81 100644 --- a/topi/tests/python/test_topi_broadcast.py +++ b/topi/tests/python/test_topi_broadcast.py @@ -90,19 +90,6 @@ def check_device(device): rhs_npy, rhs_nd = gen_operand(rhs_shape, rhs_min, rhs_max, ctx) out_npy = fnumpy(lhs_npy, rhs_npy) - if fnumpy == np.floor_divide: - # avoid check too close to X.5 and X.0 - # FIXME: floor_divide(94.90735, 0.6731018) behaves as floor(div(94.90735, 0.6731018)) - # However the result is somehow incorrect - need to further investigate. - # And looks like numpy's floor_div(a,b) is implemented different from floor(div(a,b)) - mask = np.logical_or(np.abs(np.abs(np.fmod(lhs_npy / rhs_npy, 1)) - 0.5) < 1e-6, - np.abs(np.fmod(lhs_npy / rhs_npy, 1)) < 1e-6) - if mask.any(): - lhs_npy = lhs_npy + mask * 1e-3 * rhs_npy - lhs_npy = lhs_npy.astype(dtype) - lhs_nd = tvm.nd.array(lhs_npy, ctx) if lhs_shape is not None else lhs_npy.item() - out_npy = fnumpy(lhs_npy, rhs_npy) - out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx) foo(lhs_nd, rhs_nd, out_nd) tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4) @@ -151,12 +138,14 @@ def test_divide(): (2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001) def test_floor_divide(): + def _canonical_floor_div(a,b): + return np.floor(a / b) verify_broadcast_binary_ele( - None, (10,), topi.floor_divide, np.floor_divide, rhs_min=0.0001) + None, (10,), topi.floor_divide, _canonical_floor_div, rhs_min=0.0001) verify_broadcast_binary_ele( - (), None, topi.floor_divide, np.floor_divide, rhs_min=0.0001) + (), None, topi.floor_divide, _canonical_floor_div, rhs_min=0.0001) verify_broadcast_binary_ele( - (2, 3, 64, 32), (64, 32), topi.floor_divide, np.floor_divide, rhs_min=0.0001) + (2, 3, 64, 32), (64, 32), topi.floor_divide, _canonical_floor_div, rhs_min=0.0001) def test_maximum_minmum(): verify_broadcast_binary_ele( @@ -175,10 +164,12 @@ def test_mod(): (1, 2, 2), (2,), topi.mod, np.mod, lhs_min=0.001, rhs_min=1, dtype="int32") def test_floor_mod(): + def _canonical_floor_mod(a,b): + return a - np.floor(a / b) * b verify_broadcast_binary_ele( - (1, 2, 2), (2,), topi.floor_mod, np.fmod, lhs_min=0.001, rhs_min=1, dtype="int32") + (1, 2, 2), (2,), topi.floor_mod, _canonical_floor_mod, lhs_min=0.001, rhs_min=1, dtype="int32") verify_broadcast_binary_ele( - (3, 4, 5), (3, 4, 5), topi.floor_mod, np.fmod, lhs_min=0.001, rhs_min=1, dtype="float32") + (3, 4, 5), (3, 4, 5), topi.floor_mod, _canonical_floor_mod, lhs_min=0.001, rhs_min=1, dtype="float32") def test_cmp(): # explicit specify the output type