Skip to content

Commit

Permalink
[Relay][Compile_engine] Int64 shape handling for outputs. (apache#4031)
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 authored and wweic committed Oct 18, 2019
1 parent de97cd2 commit e43d1f1
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
25 changes: 22 additions & 3 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -219,6 +219,25 @@ class ScheduleGetter :
CHECK_EQ(call_node->args.size(), 1U)
<< "Only allow function with a single tuple input";
}

// Prepare the call_node->checked_type(). For the call node inputs, we ensure that the shape is
// Int32. The following code ensures the same for the output as well.
// TODO(@icemelon): Support recursive tuple
Type call_node_type = call_node->checked_type();
if (const auto* tt = call_node->checked_type().as<TensorTypeNode>()) {
call_node_type = TensorTypeNode::make(GetShape(tt->shape), tt->dtype);
} else if (const auto* tuple_t = call_node->checked_type().as<TupleTypeNode>()) {
std::vector<Type> new_fields;
for (auto field : tuple_t->fields) {
if (const auto* tt = field.as<TensorTypeNode>()) {
new_fields.push_back(TensorTypeNode::make(GetShape(tt->shape), tt->dtype));
} else {
new_fields.push_back(field);
}
}
call_node_type = TupleTypeNode::make(new_fields);
}

CHECK(call_node->op.as<OpNode>())
<< "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);
Expand All @@ -232,7 +251,7 @@ class ScheduleGetter :
Operation(), 0));
} else {
outputs = fcompute[op](call_node->attrs, inputs,
call_node->checked_type(), target_);
call_node_type, target_);
}

int op_pattern = fpattern[op];
Expand Down
15 changes: 15 additions & 0 deletions tests/python/relay/test_backend_compile_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,23 @@ def test_compile_tuple_dup():
relay.build(relay.Module.from_expr(f), 'llvm')


def test_compile_full():
    """Build a `full` op whose shape mixes int32 and int64 IntImms.

    Shape calculations can happen in int64; this checks that the compile
    engine canonicalizes non-int32 output shapes so `relay.build` succeeds.
    """
    # Deliberately mix int32 and int64 dims to exercise the int64 path.
    shape = (tvm.expr.IntImm('int32', 1),
             tvm.expr.IntImm('int64', 16),
             tvm.expr.IntImm('int64', 16),
             tvm.expr.IntImm('int32', 64))
    output = relay.full(relay.const(0, 'int32'), shape=shape, dtype='int32')
    f = relay.Function([], output)
    mod = relay.Module.from_expr(f)
    # NOTE(review): dropped the relay.qnn.transform.CanonicalizeOps() call —
    # the module contains no QNN ops, so the pass was a no-op and only
    # coupled this core compile-engine test to the qnn contrib namespace.
    relay.build(mod, 'llvm')


if __name__ == "__main__":
test_compile_engine()
test_compile_placeholder_bypass()
test_compile_injective_with_tuple()
test_compile_tuple_dup()
test_compile_full()

0 comments on commit e43d1f1

Please sign in to comment.