-
Notifications
You must be signed in to change notification settings - Fork 64
[torch_tvm] Support Lowering to TVM even if node cannot be fused #122
base: master
Are you sure you want to change the base?
Conversation
torch_tvm/fusion_pass.cpp
Outdated
REQ((canHandle(consumer, aliasDb) || consumer->kind() == getTVMSymbol())); | ||
|
||
// if producer cannot be converted, check if consumer can be lowered to TVM |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should do this outside of this function. tryMerge
is supposed to merge two nodes into a lowerable node. Iideally we start with a seed outside. Maybe
if (canLowerSeed(consumer
If seed cannnot be lowered, we never come to tryMerge
?
Also @bwasti, maybe you should chime in here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I think we should put this kind of logic in parallel to tryMerge instead of inside. I actually think we should have some kind of flag to control minimal number of ops in order to create a fusion group. But this might not be the scope of the PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I initially implemented as two passes:
- Lower all nodes that can be lowered
- Do Fusion among nodes that are already lowered
1 works fine but fusing two tvm nodes does not quire work as is.
test/test_core.py
Outdated
ref_out, tvm_out = self.runBoth(linear, input, weight, bias) | ||
assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01) | ||
|
||
# check to verify fustion still works |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
spell check on 'fusion'.
|
||
# check single node graph | ||
def linear(a, b, c): | ||
return F.linear(a, b, c) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you check if the test fails without your changes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeap, it does indeed crash for single op version. For fusion one, I added to make sure future changes do not break it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left a couple of comments. Also asked @bwasti to chime in.
Also maybe we can add an flag to enable this feature or not? Or maybe just check if it affects performance? |
torch_tvm/fusion_pass.cpp
Outdated
REQ((canHandle(consumer, aliasDb) || consumer->kind() == getTVMSymbol())); | ||
|
||
// if producer cannot be converted, check if consumer can be lowered to TVM |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I think we should put this kind of logic in parallel to tryMerge instead of inside. I actually think we should have some kind of flag to control minimal number of ops in order to create a fusion group. But this might not be the scope of the PR.
torch_tvm/fusion_pass.cpp
Outdated
// Already converted so return no change | ||
return c10::nullopt; | ||
} | ||
// prooceed to convert current node to TVM |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo
4bfc690
to
b8eaeec
Compare
torch_tvm/fusion_pass.cpp
Outdated
if(!aliasDb.isMutable(consumer)){ | ||
REQ(!aliasDb.hasOutputWriters(consumer)); | ||
} | ||
consumer = SubgraphUtils::createSingletonSubgraph(consumer, getTVMSymbol()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also meant to suggest that we not only check if it can be lowered, but if it can be we create the singleton graph outside. Thus tryMerge always has consumer that is already lowered and it only does merging but never creation.
I think it will be cleaner this way. Sorry if it feels like I am dragging much want to drag it much.
Can you try the flow on the quantized model to see how many subgraph we could get now? |
We take PT JIT IR graph and fuse those nodes of the graph that can be lowered to Relay graph in TVM. However the current fusion logic does not create a subgraph out of a single node. You need at least two nodes that can be lowered and only then it will fuse them and create a subgraph.
In this PR, we want to enable subgraph creation for single nodes that cannot be fused with any other node. We achieve this by checking for the case where consumer can be lowered but producer cannot be.