[TIR][PASS] dtype rewrite for indexing variables #5092

Merged 25 commits on Apr 2, 2020
9 changes: 9 additions & 0 deletions include/tvm/arith/analyzer.h
@@ -114,6 +114,15 @@ class ConstIntBoundAnalyzer {
*/
ConstIntBound operator()(const PrimExpr& expr);

/*!
* \brief Analyze the expr, memoizing intermediate results to avoid redundant computation.
* \param expr The expression of interest.
* \param bound The lookup table that stores the intermediate results.
* \return the result of the analysis.
*/
ConstIntBound operator()(const PrimExpr& expr,
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound);

/*!
* \brief Update constant int bound information of var.
*
9 changes: 9 additions & 0 deletions include/tvm/tir/ir_pass.h
@@ -386,6 +386,15 @@ Stmt DecorateDeviceScope(Stmt stmt);
*/
Stmt HoistIfThenElse(Stmt stmt);

/*!
* \brief Narrow down PrimExpr datatype in stmt to target_bits.
* \note Run this pass after StorageFlatten.
* \param stmt The stmt whose indexing datatypes are to be rewritten.
* \param target_bits The bit width of the target datatype.
* \return Transformed stmt.
*/
Stmt NarrowDataType(Stmt stmt, int target_bits);

/*!
* \brief Make an user callable API LoweredFunc.
*
10 changes: 10 additions & 0 deletions include/tvm/tir/transform.h
@@ -87,6 +87,16 @@ TVM_DLL Pass LowerDeviceStorageAccessInfo();
*/
TVM_DLL Pass LowerWarpMemory();


/*!
* \brief Narrow down PrimExpr datatype in stmt to target_bits.
*
* \note Run this pass after StorageFlatten.
*
* \return The pass.
*/
TVM_DLL Pass NarrowDataType();

} // namespace transform
} // namespace tir
} // namespace tvm
1 change: 1 addition & 0 deletions python/tvm/driver/build_module.py
@@ -159,6 +159,7 @@ def lower(sch,
# Phase 1
stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
stmt = ir_pass.NarrowDataType(stmt, 32)
stmt = ir_pass.CanonicalSimplify(stmt)
for f in lower_phase1:
stmt = f(stmt)
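
With NarrowDataType wired into Phase 1 right after StorageFlatten, index expressions whose values provably fit into 32 bits are narrowed before the later passes run. A minimal sketch of the intended effect, not part of this diff (it assumes a constant int64 extent that fits in int32; names are illustrative):

    import tvm
    from tvm import te

    # A computation declared with an int64 constant extent that fits in 32 bits.
    n = tvm.tir.IntImm("int64", 1024)
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)

    # After this change the lowered loop variable should print as int32
    # rather than int64, since its range provably fits in 32 bits.
    print(tvm.lower(s, [A, B], simple_mode=True))
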
3 changes: 2 additions & 1 deletion python/tvm/tir/expr.py
@@ -370,7 +370,8 @@ def __init__(self, dom, var, iter_type, thread_tag=""):
raise TypeError("dom need to be Range")

name = var if var is not None else "iter"
var = Var(name, dtype="int32") if not isinstance(var, Var) else var
dtype = "int32" if dom is None else dom.extent.dtype
var = Var(name, dtype=dtype) if not isinstance(var, Var) else var
self.__init_handle_by_constructor__(
_ffi_api.IterVar, dom, var, iter_type, thread_tag)

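
The IterVar's underlying Var now inherits the dtype of the Range extent instead of always defaulting to int32. A small sketch of the effect through a user-facing axis helper, assuming the Python API of this TVM version:

    import tvm
    from tvm import te

    # The reduction axis is created from an int64 range, so the underlying
    # iteration variable now gets dtype "int64" rather than the old int32 default.
    extent = tvm.tir.IntImm("int64", 128)
    k = te.reduce_axis((tvm.tir.IntImm("int64", 0), extent), name="k")
    print(k.var.dtype)  # expected: int64
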
6 changes: 4 additions & 2 deletions python/tvm/tir/ir_builder.py
@@ -76,7 +76,8 @@ def dtype(self):
def __getitem__(self, index):
t = DataType(self._content_type)
if t.lanes > 1:
index = _expr.Ramp(index * t.lanes, 1, t.lanes)
base = index * t.lanes
index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
return _expr.Load(self._content_type, self._buffer_var, index)

def __setitem__(self, index, value):
@@ -87,7 +88,8 @@ def __setitem__(self, index, value):
value.dtype, self._content_type))
t = DataType(self._content_type)
if t.lanes > 1:
index = _expr.Ramp(index * t.lanes, 1, t.lanes)
base = index * t.lanes
index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
self._builder.emit(_stmt.Store(self._buffer_var, value, index))


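
The Ramp used for multi-lane loads and stores now builds its stride constant with the dtype of the computed base instead of a plain Python 1 (which became an int32 constant). A sketch with an int64 index; the buffer name and content type are illustrative, not from the diff:

    import tvm

    ib = tvm.tir.ir_builder.create()
    # A pointer with a 4-lane content type takes the Ramp-based indexing path.
    buf = ib.pointer("float32x4", name="buf")

    # With an int64 index, the Ramp base is int64, and the stride constant
    # now follows that dtype instead of defaulting to int32.
    idx = tvm.tir.IntImm("int64", 3)
    load = buf[idx]
    print(load.index)  # expected: a ramp whose base and stride are both int64
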
15 changes: 15 additions & 0 deletions python/tvm/tir/transform/transform.py
@@ -66,3 +66,18 @@ def LowerWarpMemory():
The result pass
"""
return _ffi_api.LowerWarpMemory()


def NarrowDataType():
"""Narrow down PrimExpr datatype in stmt to target_bits.

Returns
-------
fpass : tvm.ir.transform.Pass
The result pass

Note
----
Run this pass after StorageFlatten.
"""
return _ffi_api.NarrowDataType()
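
A minimal sketch of applying the new pass object; the module `mod` and the Sequential wrapper are illustrative and not part of this diff:

    import tvm

    # Compose the new pass with others; it expects to run after StorageFlatten.
    seq = tvm.ir.transform.Sequential([
        tvm.tir.transform.NarrowDataType(),
    ])
    # `mod` is assumed to be an IRModule containing flattened tir.PrimFuncs:
    # mod = seq(mod)
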
32 changes: 32 additions & 0 deletions src/arith/const_int_bound.cc
@@ -145,9 +145,30 @@ class ConstIntBoundAnalyzer::Impl :
res = Intersect(res, info.bound);
}
}
if (bound_) {
const PrimExprNode* op = expr.as<PrimExprNode>();
auto val = bound_->find(op);
if (val != bound_->end()) {
CHECK(val->second->min_value == res.min_value &&
val->second->max_value == res.max_value)
<< "Detected bound for " << expr
<< "conflicts with memorization";
}
(*bound_)[op] = ConstIntBound(res.min_value, res.max_value);
}
return res;
}

Entry VisitExpr_(const RampNode* op) final {
// op = {base + i * stride | 0 <= i < lanes}
// Entry(op) = Union(Entry(base + i * stride) | 0 <= i < lanes)
// Note that `base + i * stride` is linear w.r.t. `i`, so the union of the
// bounds at the two endpoints covers the whole set:
// Entry(op) = Union(Entry(base + i * stride) | i = 0, i = lanes - 1)
Entry a = VisitExpr(op->base);
Entry b = VisitExpr(op->base + (op->lanes - 1) * op->stride);
return Union(a, b);
}

Entry VisitExpr_(const CastNode* op) final {
Entry a = VisitExpr(op->value);
Entry b = Everything(op->dtype);
@@ -339,10 +360,13 @@ class ConstIntBoundAnalyzer::Impl :
}

private:
friend class ConstIntBoundAnalyzer;
// internal variable map
std::unordered_map<Var, Entry, ObjectHash, ObjectEqual> var_map_;
// additional bound info
std::vector<BoundInfo> additional_info_;
// lookup table for memoization
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound_{nullptr};
// constants: the limit value means unlimited
// NOTE: kNegInf/kPosInf are used to represent infinity.
static const constexpr int64_t kNegInf = ConstIntBound::kNegInf;
@@ -535,6 +559,14 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) {
return ConstIntBound(ret.min_value, ret.max_value);
}

ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr,
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound) {
impl_->bound_ = bound;
Entry ret = impl_->VisitExpr(expr);
impl_->bound_ = nullptr;
return ConstIntBound(ret.min_value, ret.max_value);
}

void ConstIntBoundAnalyzer::Update(const Var& var,
const ConstIntBound& info,
bool override) {
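
The new RampNode rule can be exercised from Python through the arithmetic analyzer. A small sketch, assuming the tvm.arith.Analyzer bindings of this TVM version; variable names are illustrative:

    import tvm
    from tvm import tir

    ana = tvm.arith.Analyzer()
    i = tir.Var("i", "int32")
    # Tell the analyzer that 0 <= i <= 7.
    ana.update(i, tvm.arith.ConstIntBound(0, 7))

    # ramp(i*4, 1, 4) covers {i*4 + k | 0 <= k < 4}; since the expression is
    # linear in the lane index, the union of the two endpoint bounds, [0, 28]
    # and [3, 31], gives [0, 31] for the whole vector.
    bound = ana.const_int_bound(tir.Ramp(i * 4, 1, 4))
    print(bound.min_value, bound.max_value)  # expected: 0 31
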
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_cpu.cc
@@ -943,7 +943,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
PrimExpr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent);
CreateSerialFor(MakeValue(begin),
MakeValue(end),
ConstInt32(1),
llvm::ConstantInt::getSigned(GetLLVMType(end), 1),
op->loop_var,
op->body);
}
3 changes: 2 additions & 1 deletion src/target/llvm/codegen_llvm.cc
@@ -1121,7 +1121,8 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) {
CHECK(op->for_type == ForType::Serial);
}
CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
ConstInt32(1), op->loop_var, op->body);
llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1),
op->loop_var, op->body);
}


2 changes: 1 addition & 1 deletion src/tir/ir/buffer.cc
@@ -448,7 +448,7 @@ Buffer BufferNode::make(Var data,
n->buffer_type = buffer_type;
if (n->buffer_type == kAutoBroadcast && n->shape.size() > 0 && n->strides.empty()) {
for (size_t i = 0; i < n->shape.size(); ++i) {
n->strides.push_back(Var("stride"));
n->strides.push_back(Var("stride", n->shape[i].dtype()));
}
}
return Buffer(n);
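
The implicit stride variables created for auto-broadcast buffers now take the dtype of the corresponding shape dimension instead of defaulting to int32. A sketch from the Python side, assuming decl_buffer's auto_broadcast mode; the buffer name is illustrative:

    import tvm

    shape = (tvm.tir.IntImm("int64", 8), tvm.tir.IntImm("int64", 16))
    buf = tvm.tir.decl_buffer(shape, "float32", name="A",
                              buffer_type="auto_broadcast")
    # The implicitly created stride vars should now be int64, like the shape.
    print([s.dtype for s in buf.strides])  # expected: ['int64', 'int64']
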
1 change: 1 addition & 0 deletions src/tir/pass/ffi_api.cc
@@ -165,5 +165,6 @@ REGISTER_PASS(InstrumentBoundCheckers);
REGISTER_PASS(VerifyCompactBuffer);
REGISTER_PASS(HoistIfThenElse);
REGISTER_PASS(InferFragment)
REGISTER_PASS(NarrowDataType);
} // namespace tir
} // namespace tvm
2 changes: 1 addition & 1 deletion src/tir/pass/loop_partition.cc
@@ -587,7 +587,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt b
// If the loop extent is 1, do not create the loop anymore
return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
} else {
return ForNode::make(for_node->loop_var, 0, extent,
return ForNode::make(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent,
for_node->for_type, for_node->device_api, body);
}
}
4 changes: 3 additions & 1 deletion src/tir/pass/unroll_loop.cc
@@ -160,7 +160,9 @@ class LoopUnroller : public StmtExprMutator {
PrimExpr extent = tir::Simplify(op->extent);
const IntImmNode *v1 = extent.as<IntImmNode>();
int value = -1;
if (v1 != nullptr) {
// integers that do not fit in int32_t are treated as symbolic,
// as it's impossible to unroll such large loops
if (v1 != nullptr && v1->value <= std::numeric_limits<int>::max()) {

Review thread on this line (marked as resolved by hzfan):

  Reviewer (Member): I wonder if we should use int32_t here rather than int. I'm not
  sure, but I worry that int may represent different widths (int16_t, int32_t or
  int64_t) on different systems and devices.

  Author (Contributor): My motivation here is to prevent overflow in the next line
  (which uses int):

      value = static_cast<int>(v1->value);

  IMO it might be fine to use int here, as it's consistent with other parts of the
  pass. What do you think?

  Reviewer (Member): That's fine.

value = static_cast<int>(v1->value);
}
return value;