From e43d1f197d222c48d4bf0fbdb06fff75e19419b8 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 30 Sep 2019 10:06:35 -0700 Subject: [PATCH] [Relay][Compile_engine] Int64 shape handling for outputs. (#4031) --- src/relay/backend/compile_engine.cc | 25 ++++++++++++++++--- .../relay/test_backend_compile_engine.py | 15 +++++++++++ 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index c88703ea4d05..a75cdb299bf4 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -219,6 +219,25 @@ class ScheduleGetter : CHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; } + + // Prepare the call_node->checked_type(). For the call node inputs, we ensure that the shape is + // Int32. Following code ensures the same for the output as well. + // TODO(@icemelon): Support recursive tuple + Type call_node_type = call_node->checked_type(); + if (const auto* tt = call_node->checked_type().as()) { + call_node_type = TensorTypeNode::make(GetShape(tt->shape), tt->dtype); + } else if (const auto* tuple_t = call_node->checked_type().as()) { + std::vector new_fields; + for (auto field : tuple_t->fields) { + if (const auto* tt = field.as()) { + new_fields.push_back(TensorTypeNode::make(GetShape(tt->shape), tt->dtype)); + } else { + new_fields.push_back(field); + } + } + call_node_type = TupleTypeNode::make(new_fields); + } + CHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; Op op = Downcast(call_node->op); @@ -232,7 +251,7 @@ class ScheduleGetter : Operation(), 0)); } else { outputs = fcompute[op](call_node->attrs, inputs, - call_node->checked_type(), target_); + call_node_type, target_); } int op_pattern = fpattern[op]; diff --git a/tests/python/relay/test_backend_compile_engine.py b/tests/python/relay/test_backend_compile_engine.py index ea16a8d6122e..b1f41a43148c 100644 --- a/tests/python/relay/test_backend_compile_engine.py +++ b/tests/python/relay/test_backend_compile_engine.py @@ -79,8 +79,23 @@ def test_compile_tuple_dup(): relay.build(relay.Module.from_expr(f), 'llvm') +def test_compile_full(): + # Shape calculations can happen in int64. The test checks that full operator + # can handle when shapes are not int32 + shape = (tvm.expr.IntImm('int32', 1), + tvm.expr.IntImm('int64', 16), + tvm.expr.IntImm('int64', 16), + tvm.expr.IntImm('int32', 64)) + output = relay.full(relay.const(0, 'int32'), shape=shape, dtype='int32') + f = relay.Function([], output) + mod = relay.Module.from_expr(f) + mod = relay.qnn.transform.CanonicalizeOps()(mod) + relay.build(mod, 'llvm') + + if __name__ == "__main__": test_compile_engine() test_compile_placeholder_bypass() test_compile_injective_with_tuple() test_compile_tuple_dup() + test_compile_full()