diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 2356634c4ed0..ae256629f3b1 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -658,7 +658,7 @@ bool BatchMatmulRel(const Array& types, const auto* x = types[0].as(); const auto* y = types[1].as(); if (x == nullptr || y == nullptr) return false; - if (x->shape.size() != 3 || y->shape.size() != 3) return false; + CHECK(x->shape.size() == 3 && y->shape.size() == 3); CHECK(reporter->AssertEQ(x->shape[0], y->shape[0])) << "BatchDot: batch dimension doesn't match, " << " x shape=" << x->shape