Skip to content
This repository has been archived by the owner on Apr 1, 2021. It is now read-only.

Commit

Permalink
[torch_tvm] Support Lowering to TVM even if node cannot be fused
Browse files Browse the repository at this point in the history
  • Loading branch information
serhaty committed Sep 26, 2019
1 parent 16ca92c commit b8eaeec
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 3 deletions.
28 changes: 28 additions & 0 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from test.util import TVMTest
import torch
import torch_tvm
import torch.nn.functional as F


class TestCore(TVMTest):
Expand Down Expand Up @@ -135,5 +136,32 @@ def dropout_inference(a, b, c):
str(tvm_graph_inference.graph_for(input_a, input_b, input_c)), \
"dropout must be removed during inference."

@TVMTest.given(
shape=TVMTest.rand_shape(rank=2, min_dim=4),
out_features=TVMTest.rand_int(3, 6),
)
def test_fuse_single_node(self, shape, out_features):
print("Running test for test_fuse_single_node")
input = torch.rand(shape)
weight = torch.rand(out_features, shape[1])
bias = torch.rand(out_features)

# check single node graph
def linear(a, b, c):
return F.linear(a, b, c)

ref_out, tvm_out = self.runBoth(linear, input, weight, bias)
assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01)

# check to verify fusion still works
def linearSum(a, b, c):
return F.linear(a, b, c) + 2.0

ref_out, tvm_out = self.runBoth(linearSum, input, weight, bias)
assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01)




if __name__ == "__main__":
unittest.main()
26 changes: 23 additions & 3 deletions torch_tvm/fusion_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ bool canHandle(Block* block, AliasDb& aliasDb) {
return true;
}

bool canLowerSeed(Node* node, AliasDb& aliasDb) {
return canHandle(node, aliasDb) || node->kind() == getTVMSymbol();
}

#define REQ(cond) \
if (!(cond)) { \
GRAPH_DEBUG("Failed cond " #cond "\n"); \
Expand All @@ -59,9 +63,22 @@ c10::optional<Node*> tryMerge(
consumer->kind().toQualString(),
":\n");

// Symbolic checks
REQ(canHandle(producer, aliasDb));
REQ((canHandle(consumer, aliasDb) || consumer->kind() == getTVMSymbol()));
// if producer cannot be converted, lower consumer to TVM
if(!canHandle(producer, aliasDb)){
if(consumer->kind() == getTVMSymbol() || consumer->hasAttribute(attr::Subgraph)) {
// Already converted so return no change
return c10::nullopt;
}
// proceed to convert current node to TVM
if(!aliasDb.isMutable(consumer)){
REQ(!aliasDb.hasOutputWriters(consumer));
}
consumer = SubgraphUtils::createSingletonSubgraph(consumer, getTVMSymbol());

return consumer;
}

// Nodes can be fused

// Alias checks
// Requirement:
Expand Down Expand Up @@ -107,6 +124,9 @@ std::pair<graph_node_list::iterator, bool> scanNode(
Block* block) {
auto inputs = sortReverseTopological(consumer->inputs(), block);
for (auto input : inputs) {
if(!canLowerSeed(consumer, aliasDb)) {
continue;
}
if (auto group = tryMerge(consumer, input->node(), aliasDb)) {
// we successfully merged, so the new group's `inputs` may have
// changed. So rescan the new group for more merging opportunities.
Expand Down

0 comments on commit b8eaeec

Please sign in to comment.