diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index 1b27dea28825..73899097fe52 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -57,6 +57,11 @@ bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, } else { if (weight == nullptr) return false; Array wshape = weight->shape; + CHECK(static_cast(weight->shape.size()) == 2); + CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], + weight->shape[1])) + << "DenseRel: input dimension doesn't match," + << " data shape=" << data->shape << ", weight shape=" << weight->shape; oshape.Set((oshape.size() - 1), wshape[0]); } diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 6723369f3886..2c8b89695c73 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import pytest import tvm import scipy from tvm import relay @@ -336,6 +337,16 @@ def test_batch_norm(): relay.ty.TensorType((3,), dtype) ])) +@pytest.mark.xfail +def test_dense_type_check(): + dtype = 'float16' + n, c , h, w = 2, 2 , 2 ,2 + x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) + # it should fail since it does not match with m(2) + mismatch_w = 3 + w = relay.var("w", relay.TensorType((2, mismatch_w), dtype)) + y = relay.nn.dense(x, w) + yy = run_infer_type(y) def test_dense(): for dtype in ['float16', 'float32']: