-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TIR] Add TIR While node #7425
[TIR] Add TIR While node #7425
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, very nice addition!
Thanks @masahi , before we merge it in. it would be really awesome to go through the current list of passes and check if special handling of while is needed (so we won't bring in new bugs because the mix). Some of the example passes could include (I would at least check passes that need special IfThenElse handling) For example, I can see the need to update following pass:
|
also cc @zxybazh please help to review this PR |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR! It looks good to me :-) Surprisingly it doesn't need to change any passes besides storage_rewrite :-)
CC @spectrometerHBH: we might want to have it supported in TensorIR too, either like a syntactic sugar to opaque binding or other ways |
@@ -109,6 +110,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> { | |||
IR_STMT_FUNCTOR_DISPATCH(AttrStmtNode); | |||
IR_STMT_FUNCTOR_DISPATCH(IfThenElseNode); | |||
IR_STMT_FUNCTOR_DISPATCH(ForNode); | |||
IR_STMT_FUNCTOR_DISPATCH(WhileNode); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need checks through the current passes, per my comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @masahi! Looks good to me.
@tqchen @junrushao1994 @vinx13 I went through the passes and here is my summary:
|
@tqchen Can you have a look? |
I left a comment for inject virtual thread, @junrushao1994 @ZihengJiang @vinx13 would be great if you can also help check the StorageAccessVisitor |
I've checked |
@vinx13 Ok, For tvm/src/tir/transforms/storage_rewrite.cc Lines 241 to 251 in 7340c02
But I don't see how we should update tvm/src/tir/transforms/storage_rewrite.cc Lines 757 to 773 in 7340c02
|
Thanks @masahi , it would also be great for you to spend a bit more time to look into these passes :) It certainly takes more time, but we will also have more experts in TIR passes :) Please also consider to add a test case to the passes that need while handling |
@masahi For tvm/src/tir/transforms/storage_rewrite.cc Lines 440 to 452 in 7340c02
|
ok, to me it's not obvious what it is doing, time for another deep dive... |
@tqchen @vinx13 @junrushao1994 Does the behavior of In the following IR, "A" and "B" buffers, which are allocated in
In the following IR, all buffers, including the one allocated inside
|
@vinx13 can you please take another look at the PR and manage? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @masahi ! the change has addressed my previous comments. Please add testcases to transforms that touches requires special While handling to cover these passes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work! Thanks @masahi
@tqchen @junrushao1994 @vinx13 @ZihengJiang @zxybazh I came to a conclusion that While node doesn't need a special handling in The first observation is that even if I remove all If we look at the visitor for tvm/src/tir/transforms/storage_rewrite.cc Lines 440 to 452 in 7340c02
it only does something special when attach_map_ has an entry for this node. Here comes the second observation: the only case whereattach_map_ can have an entry for ForNode is if this ForNode is a parallel for loop, due to these lines: tvm/src/tir/transforms/storage_rewrite.cc Lines 766 to 772 in 7340c02
Together, these two handler for tvm/src/tir/transforms/storage_rewrite.cc Line 447 in 7340c02
test_parallel_alloc() . For other kinds of For loop, a merged allocation is placed at the global scope, see tvm/src/tir/transforms/storage_rewrite.cc Lines 457 to 461 in 7340c02
Since
storage_rewrite .
I think I nailed it, thoughts? |
@masahi You are right, thanks for looking into this |
That makes sense to me. Thanks for diving deep into this issue! |
cc @tqchen please take a look |
@masahi you are right that the MakeAttach is only needed for parallel for loop, where we can nolonger lift the memory to the outside(otherwise the memory won't be thread local) |
@junrushao1994 @vinx13 please help to manage the PR |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One minor comment
@junrushao1994 @vinx13 @tqchen ready to merge...!! |
Thanks everyone @masahi @tqchen @junrushao1994 @giuseros @zxybazh @ZihengJiang |
Really awesome work!!! |
Thank you very much for the reviews!! |
* add while node * update visitors * binary search lowering works * llvm codegen working * cuda codegen working * nms updated to use while loop * add missing upper bound check too * add mandelbrot test * add gpu mandel commit ee2363b Author: Masahiro Masuda <[email protected]> Date: Fri Jan 29 11:44:02 2021 +0900 enable extern lib offload for nvptx * rename test * run black * add doc * add collatz test * add while + vectorize test * simplify bin search * Add special case visit method to storage_access.cc * disallow while loop inside vectorized loop * disallow trivial condition since we do not have break * error out in CoprocSync for now * error out LiftAttrScope for now * add placeholder to inject_vpthread * refactor to use MakeAttach * handle WhileNode in InplaceOpVerifier * error out in InjectVirtualThread * try handle WhileNode in StoragePlanRewriter * remove WhileNode visitor from storage rewrite * add while loop storage rewrite test * update tests * move test_vectorize_while_fail to test_tir_transform_vectorize.py
* add while node * update visitors * binary search lowering works * llvm codegen working * cuda codegen working * nms updated to use while loop * add missing upper bound check too * add mandelbrot test * add gpu mandel commit ee2363b Author: Masahiro Masuda <[email protected]> Date: Fri Jan 29 11:44:02 2021 +0900 enable extern lib offload for nvptx * rename test * run black * add doc * add collatz test * add while + vectorize test * simplify bin search * Add special case visit method to storage_access.cc * disallow while loop inside vectorized loop * disallow trivial condition since we do not have break * error out in CoprocSync for now * error out LiftAttrScope for now * add placeholder to inject_vpthread * refactor to use MakeAttach * handle WhileNode in InplaceOpVerifier * error out in InjectVirtualThread * try handle WhileNode in StoragePlanRewriter * remove WhileNode visitor from storage rewrite * add while loop storage rewrite test * update tests * move test_vectorize_while_fail to test_tir_transform_vectorize.py
This is an implementation of TIR While node as discussed in RFC https://discuss.tvm.apache.org/t/rfc-add-while-loop-node-to-tir/9028. It supercedes my earlier attempt in #7385.
The PR consists of
storage_rewrite.cc
, everything else uses the default visitor)Hybrid script support etc are left for future work.
Now we can write binary search succinctly as follows:
As another nice use of while loop, I added a test that draws a useless mandelbrot set 🙂
@tqchen @junrushao1994 @vinx13 @mbrookhart @zhiics @kevinthesun @anijain2305 @trevor-m