diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index c6e74d42843d..1e846b0b3a61 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -507,12 +507,25 @@ const String ExprUtils::GetSpanName(const Expr& expr, const String& suffix) { return name; } -const Array ExprUtils::GetShape(const Expr& expr) { - const auto& shape_opt = Downcast(relax::GetStructInfo(expr))->GetShape(); - ICHECK(shape_opt.defined()) << "Shape is not defined for " << expr; +const Array ExprUtils::GetShape(const relax::TensorStructInfo& sinfo, bool as_int) { + const auto& shape_opt = sinfo->GetShape(); + if (!shape_opt.defined()) { + return Array(); + } + if (as_int) { + Array shape; + for (const auto& s : shape_opt.value()) { + shape.push_back(s->IsInstance() ? s : Integer(-1)); + } + return shape; + } return shape_opt.value(); } +const Array ExprUtils::GetShape(const Expr& expr, bool as_int) { + return GetShape(Downcast(relax::GetStructInfo(expr)), as_int); +} + const DataType ExprUtils::GetDataType(const Expr& expr) { return Downcast(relax::GetStructInfo(expr))->dtype; } diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h index d7758cc23d8b..7fb9c87a99f9 100644 --- a/src/contrib/msc/core/utils.h +++ b/src/contrib/msc/core/utils.h @@ -398,7 +398,9 @@ class ExprUtils { * \brief Get shape of expr. * \return The shape. */ - TVM_DLL static const Array GetShape(const Expr& expr); + TVM_DLL static const Array GetShape(const relax::TensorStructInfo& sinfo, + bool as_int = true); + TVM_DLL static const Array GetShape(const Expr& expr, bool as_int = true); /*! * \brief Get dtype of expr. diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc index 0f95f2d20622..542e15d06c3c 100644 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc @@ -90,7 +90,7 @@ const Array BroadcastShape(const Array& src_shape, ICHECK(ArrayUtils::Broadcastable(leading_shape, out_shape)) << "Only support elemwise ops with leading or tailing expand"; return leading_shape; -}; +} Expr RewriteElemwise(BlockBuilder builder, const Var& var, const Call& src_call, const Map& new_calls, const String& config) {