From 4bfc690bafea5ad1c339a5b2ffa938dde4a66cc6 Mon Sep 17 00:00:00 2001 From: serhaty Date: Mon, 23 Sep 2019 15:48:56 -0700 Subject: [PATCH] [torch_tvm] Support Lowering to TVM even if node cannot be fused --- test/test_core.py | 28 ++++++++++++++++++++++++++++ torch_tvm/fusion_pass.cpp | 26 +++++++++++++++++++++++--- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/test/test_core.py b/test/test_core.py index 7fce71c..b7cc143 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -2,6 +2,7 @@ from test.util import TVMTest import torch import torch_tvm +import torch.nn.functional as F class TestCore(TVMTest): @@ -135,5 +136,32 @@ def dropout_inference(a, b, c): str(tvm_graph_inference.graph_for(input_a, input_b, input_c)), \ "dropout must be removed during inference." + @TVMTest.given( + shape=TVMTest.rand_shape(rank=2, min_dim=4), + out_features=TVMTest.rand_int(3, 6), + ) + def test_fuse_single_node(self, shape, out_features): + print("Running test for test_fuse_single_node") + input = torch.rand(shape) + weight = torch.rand(out_features, shape[1]) + bias = torch.rand(out_features) + + # check single node graph + def linear(a, b, c): + return F.linear(a, b, c) + + ref_out, tvm_out = self.runBoth(linear, input, weight, bias) + assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01) + + # check to verify fusion still works + def linearSum(a, b, c): + return F.linear(a, b, c) + 2.0 + + ref_out, tvm_out = self.runBoth(linearSum, input, weight, bias) + assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01) + + + + if __name__ == "__main__": unittest.main() diff --git a/torch_tvm/fusion_pass.cpp b/torch_tvm/fusion_pass.cpp index 1dee51c..8fb14fa 100644 --- a/torch_tvm/fusion_pass.cpp +++ b/torch_tvm/fusion_pass.cpp @@ -43,6 +43,10 @@ bool canHandle(Block* block, AliasDb& aliasDb) { return true; } +bool canLowerSeed(Node* node, AliasDb& aliasDb) { + return canHandle(node, aliasDb) || node->kind() == getTVMSymbol(); +} + #define REQ(cond) \ if (!(cond)) { \ GRAPH_DEBUG("Failed cond " #cond "\n"); \ @@ -59,9 +63,22 @@ c10::optional tryMerge( consumer->kind().toQualString(), ":\n"); - // Symbolic checks - REQ(canHandle(producer, aliasDb)); - REQ((canHandle(consumer, aliasDb) || consumer->kind() == getTVMSymbol())); + // if producer cannot be converted, check if consumer can be lowered to TVM + if(!canHandle(producer, aliasDb)){ + if(consumer->kind() == getTVMSymbol() || consumer->hasAttribute(attr::Subgraph)) { + // Already converted so return no change + return c10::nullopt; + } + // proceed to convert current node to TVM + if(!aliasDb.isMutable(consumer)){ + REQ(!aliasDb.hasOutputWriters(consumer)); + } + consumer = SubgraphUtils::createSingletonSubgraph(consumer, getTVMSymbol()); + + return consumer; + } + + // Nodes can be fused // Alias checks // Requirement: @@ -107,6 +124,9 @@ std::pair scanNode( Block* block) { auto inputs = sortReverseTopological(consumer->inputs(), block); for (auto input : inputs) { + if(!canLowerSeed(consumer, aliasDb)) { + continue; + } if (auto group = tryMerge(consumer, input->node(), aliasDb)) { // we successfully merged, so the new group's `inputs` may have // changed. So rescan the new group for more merging opportunities.