Skip to content

Commit

Permalink
fix a few bugs with shape inference and types in the onnx importer (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart authored and Trevor Morris committed Jun 18, 2020
1 parent a133cc2 commit 885de3c
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
4 changes: 3 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ def _impl_v1(cls, inputs, attr, params):
def _impl_v5(cls, inputs, attr, params):
if get_name(inputs[1]) in params:
# pop shape out of parameters since it wont be needed later.
shape = tuple(params.pop(inputs[1].name_hint).asnumpy())
shape = tuple(params.pop(inputs[1].name_hint).asnumpy().astype("int32"))
out = _op.reshape(inputs[0], shape)
else:
data, shape = inputs
Expand Down Expand Up @@ -1485,6 +1485,8 @@ def _impl_v9(cls, inputs, attr, params):
raise ValueError("Expect 1 input only")

output = AttrCvt(op_name='argwhere')(inputs, attr, params)
# ONNX NonZero always outputs int64
output = _op.cast(output, "int64")
return _op.transpose(output, axes=(1, 0))

class TopK(OnnxOpConverter):
Expand Down
16 changes: 12 additions & 4 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -829,9 +829,13 @@ bool TakeRel(const Array<Type>& types,
// `types` contains: [data, indices, result]
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
CHECK(data != nullptr);
if (data == nullptr) {
return false;
}
const auto* indices = types[1].as<TensorTypeNode>();
CHECK(indices != nullptr);
if (indices == nullptr) {
return false;
}
CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer";
const auto param = attrs.as<TakeAttrs>();
CHECK(param != nullptr);
Expand Down Expand Up @@ -2325,15 +2329,19 @@ bool LayoutTransformRel(const Array<Type>& types,
const Attrs& attrs,
const TypeReporter& reporter) {
const auto* data = types[0].as<TensorTypeNode>();
CHECK(data != nullptr);
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "LayoutTransform: expect input data type to be TensorType but get "
<< types[0];
return false;
}
const LayoutTransformAttrs* params = attrs.as<LayoutTransformAttrs>();

Layout src_layout(params->src_layout);
Layout dst_layout(params->dst_layout);

CHECK(src_layout.defined() && dst_layout.defined())
<< "cannot convert from/to undefined layout";

auto layout_converter = tir::BijectiveLayout(src_layout, dst_layout);
CHECK(layout_converter.defined())
<< "cannot convert from " << params->src_layout << " to " << params->dst_layout;
Expand Down
4 changes: 3 additions & 1 deletion src/relay/op/tensor/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,9 @@ bool ShapeOfRel(const Array<Type>& types,
const TypeReporter& reporter) {
CHECK_EQ(num_inputs, 1);
auto tt = types[0].as<TensorTypeNode>();
CHECK(tt != nullptr);
if (tt == nullptr) {
return false;
}
const auto* param = attrs.as<ShapeOfAttrs>();
CHECK(param != nullptr);
auto rank_shape = RankShape(tt->shape);
Expand Down

0 comments on commit 885de3c

Please sign in to comment.