[TIR][PASS] dtype rewrite for indexing variables #5092
I will do another round of review tomorrow. @tqchen, could you also take a look?
@junrushao1994 @ZihengJiang @merrymercy @Hzfengsy it would be great if you could also help take a look.
bounds = te.schedule.InferBound(sch)
stmt = te.schedule.ScheduleOps(sch, bounds)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False)
stmt = tvm.tir.ir_pass.NarrowDataType(stmt, 32)
It would be great if we could add more test cases:
- Please also consider using ir_builder to build loops directly
- Have a test case that narrows to i16
- Have a loop variable that occurs in multiple expressions, where one expr overflows and another does not
Tests added. test_slice covers the last case you mentioned.
src/tir/pass/narrow_datatype.cc
Outdated
ConstIntBound bound = analyzer_.const_int_bound(e);
int64_t ubound = Downcast<IntImm, PrimExpr>(max_value(DataType::Int(target_bits_)))->value;
int64_t lbound = Downcast<IntImm, PrimExpr>(min_value(DataType::Int(target_bits_)))->value;
if (e.dtype().bits() <= target_bits_ ||
Shall we rewrite lower bits into higher ones? cc @yzhliu
Do you have an example? We previously reviewed some of the scenarios and did not see a need to do so.
To clarify, this code seems to indicate that we can rewrite lower bits into higher ones, and I think we do not need this behavior.
Yes, I think this pass aligns with what you said. It only narrows and never promotes. This code means that lower bits fit into higher bits. For example, consider the case where e is i.i64 <= j.i64, a bool expression with only 1 bit. So e fits into i32 and does not hinder narrowing i to i32.
https://github.com/apache/incubator-tvm/pull/5092/files#diff-98ae729cf00e30cff311ed80b4a25df9R129 ensures dtype promotion does not occur.
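To illustrate the "fits into" check described above, here is a minimal Python sketch. It is not the code in this PR: the helper names `int_range` and `fits` are hypothetical, and the bounds are passed in directly instead of coming from the analyzer's const int bound.

```python
def int_range(bits):
    """Return (min, max) of a signed two's-complement integer with `bits` bits."""
    return -(1 << (bits - 1)), (1 << (bits - 1)) - 1


def fits(lo, hi, bits):
    """True if every value in [lo, hi] is representable as a signed `bits`-bit integer."""
    tmin, tmax = int_range(bits)
    return tmin <= lo and hi <= tmax


# A comparison such as `i <= j` only takes the values 0 or 1, so it fits in i32
# and does not block narrowing `i`.
cmp_fits = fits(0, 1, 32)

# An index expression that can reach 2**31 does not fit in i32.
big_fits = fits(0, 1 << 31, 32)
```

Under these assumptions, an expression only forces a wider type when its bound actually exceeds the target range.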
src/tir/pass/narrow_datatype.cc
Outdated
if (vmap.find(op) == vmap.end()) {
  vmap[op] = op->dtype.with_bits(bits);
} else {
  vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
It would be great to add more comments here, e.g. "We are taking the maximum bits over all the possible Exprs in which a Var occurs."
Thanks for the example. Comments added.
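The "take the maximum over all occurrences" rule discussed in this thread can be sketched in a few lines of Python. This is only an illustration under assumed names (`record_bits`, and a plain dict standing in for vmap), not the pass itself.

```python
def record_bits(vmap, var, bits):
    """Keep, per variable, the widest bit width required by any expression using it."""
    vmap[var] = max(vmap.get(var, bits), bits)
    return vmap


vmap = {}
record_bits(vmap, "i", 32)  # e.g. one use of i is safe in 32 bits
record_bits(vmap, "i", 64)  # e.g. another use of i would overflow 32 bits
record_bits(vmap, "j", 32)  # j is only used where 32 bits suffice
```

Here `i` ends up at 64 bits because a single overflowing use is enough to block narrowing, while `j` stays at 32.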
src/tir/pass/narrow_datatype.cc
Outdated
void VisitExpr_(const VarNode* op) {
  if (vextent_.find(op) != vextent_.end()) {
    int bits = std::min(vextent_[op].bits(), bits_);
Add more comments about the algorithm
Added. Here we ensure that datatype is not promoted.
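As a rough reading of the std::min in the excerpt above, the bits propagated to a variable are capped by the width it already has, so the result can shrink but never grow. The function name `capped_bits` below is hypothetical.

```python
def capped_bits(current_bits, propagated_bits):
    """Never promote: the result is at most the variable's current width."""
    return min(current_bits, propagated_bits)


stays_narrow = capped_bits(32, 64)  # a 32-bit var is not widened to 64
gets_narrowed = capped_bits(64, 32)  # a 64-bit var may shrink to 32
```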
src/tir/pass/narrow_datatype.cc
Outdated
using arith::IRMutatorWithAnalyzer;
using arith::ConstIntBound;

class DataTypeVisitor final : public StmtExprVisitor {
Please document:
- The input, and what vmap will eventually store
- The general algorithm we use (e.g. we propagate the bits backwards into vmap)
Documented.
src/tir/pass/narrow_datatype.cc
Outdated
void VisitExpr(const PrimExpr& e) {
  if (e.dtype().is_int()) {
    int bits = max_bits_;
    ConstIntBound bound = analyzer_.const_int_bound(e);
NOTE: the constant int bound here is not necessarily the most efficient for deeply nested expressions, as we are recursively calling const int bound for all sub-expressions of e as well (when we recursively visit). Perhaps we want to add a const int bound with a memoization option that allows the analyzer to pass a memo (of each subexpr) to the const int bound.
We could add it as a TODO item, or do it directly in this PR, but it would be great to have it resolved in the next few weeks. cc @yzhliu
What about a state flag like with_memo in class ConstIntBoundAnalyzer? Once it is set, we first look up in the table when a new Expr comes. We can have an unordered_map like unordered_map<PrimExpr*, Entry> in ConstIntBoundAnalyzer to achieve this.
A memo can have unintended consequences if the vars can be bound to different context-dependent info (e.g. in if (x<10) { x+1; } else x; the condition x<10 is only effective in the then branch).
I would say perhaps we could have another API to pass in an unordered map, and ask the analyzer to record every intermediate step into the map.
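The concern about context-dependent bounds can be shown with a toy Python sketch. Everything here is assumed for illustration: the memo is deliberately keyed on the expression text only, and the ranges [0, 9] and [0, 99] for x are made up. It is not how ConstIntBoundAnalyzer works internally.

```python
def bound_of_x_plus_1(x_bound):
    """Bound of `x + 1` given the (lo, hi) bound of x."""
    lo, hi = x_bound
    return lo + 1, hi + 1


memo = {}


def memoized_bound(expr_key, x_bound):
    # Deliberately keyed on the expression only, ignoring the surrounding context.
    if expr_key not in memo:
        memo[expr_key] = bound_of_x_plus_1(x_bound)
    return memo[expr_key]


# Inside `if (x < 10)`, x is in [0, 9], so x + 1 is in [1, 10].
then_bound = memoized_bound("x + 1", (0, 9))

# Outside the branch x is in [0, 99], but the memo still returns the branch-local bound.
stale_bound = memoized_bound("x + 1", (0, 99))
correct_bound = bound_of_x_plus_1((0, 99))
```

The stale entry is too tight for the second context, which is why a caller-supplied map scoped to one traversal may be safer than a flag on the analyzer.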
Some more comments, mainly with respect to clarity of the code, test coverage, and efficiency concerns. Thanks for bringing in the PR. Given that this is critical to a lot of the codebase, let us try to hold the highest standard (https://docs.tvm.ai/contribute/code_review.html#hold-the-highest-standard) :) Let us work to polish it to the best state. Thanks @hzfan for the good work so far.
@spectrometerHBH @FrozenGene @Hzfengsy please also help to take a look
if (v1 != nullptr) {
// integers that do not fit in int32_t are treated as symbolic,
// as it's impossible to unroll such large loops
if (v1 != nullptr && v1->value <= std::numeric_limits<int>::max()) {
I wonder if we should use int32_t here rather than int. I'm not sure, but I worry that int may represent different types (int16_t, int32_t or int64_t) on different systems and devices.
My motivation here is to prevent overflow in the next line (which uses int):
value = static_cast<int>(v1->value);
IMO it might be fine to use int here, as it's consistent with other parts of the pass. What do you think?
That's fine
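For readers following the overflow concern, the Python sketch below emulates what an unchecked narrowing cast could do and how the guard avoids it. It assumes a 32-bit int and two's-complement wraparound, which is typical but not guaranteed on every platform. The names `wrap_to_int32` and `unroll_extent` are hypothetical.

```python
INT_MAX = 2**31 - 1  # assumption: `int` is 32 bits wide


def wrap_to_int32(v):
    """Emulate two's-complement wraparound of a 64-bit value to 32 bits."""
    v &= 0xFFFFFFFF
    return v - (1 << 32) if v >= (1 << 31) else v


def unroll_extent(v):
    """Return the extent if it fits in `int`, else None to mean 'treat as symbolic'."""
    return v if v <= INT_MAX else None


wrapped = wrap_to_int32(2**32 + 5)   # an unguarded cast silently changes the value
guarded = unroll_extent(2**32 + 5)   # the guard rejects the value instead
small = unroll_extent(16)            # ordinary extents pass through unchanged
```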
Changes:
Some background:
https://discuss.tvm.ai/t/rfc-support-for-large-tensors/5643
Take the following as an example:
m, n = te.var('m', dtype='int64'), te.var('n', dtype='int64')
yields
m, n = tvm.tir.const(2, dtype="int64"), tvm.tir.const(2, dtype="int64")
yields
@yzhliu Could you review?