Skip to content

Commit

Permalink
format fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Archermmt committed Sep 7, 2024
1 parent bf56c82 commit 4098908
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
19 changes: 16 additions & 3 deletions src/contrib/msc/core/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -507,12 +507,25 @@ const String ExprUtils::GetSpanName(const Expr& expr, const String& suffix) {
return name;
}

const Array<PrimExpr> ExprUtils::GetShape(const Expr& expr) {
const auto& shape_opt = Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr))->GetShape();
ICHECK(shape_opt.defined()) << "Shape is not defined for " << expr;
const Array<PrimExpr> ExprUtils::GetShape(const relax::TensorStructInfo& sinfo, bool as_int) {
const auto& shape_opt = sinfo->GetShape();
if (!shape_opt.defined()) {
return Array<PrimExpr>();
}
if (as_int) {
Array<PrimExpr> shape;
for (const auto& s : shape_opt.value()) {
shape.push_back(s->IsInstance<IntImmNode>() ? s : Integer(-1));
}
return shape;
}
return shape_opt.value();
}

const Array<PrimExpr> ExprUtils::GetShape(const Expr& expr, bool as_int) {
return GetShape(Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr)), as_int);
}

const DataType ExprUtils::GetDataType(const Expr& expr) {
return Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr))->dtype;
}
Expand Down
4 changes: 3 additions & 1 deletion src/contrib/msc/core/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,9 @@ class ExprUtils {
* \brief Get shape of expr.
* \return The shape.
*/
TVM_DLL static const Array<PrimExpr> GetShape(const Expr& expr);
TVM_DLL static const Array<PrimExpr> GetShape(const relax::TensorStructInfo& sinfo,
bool as_int = true);
TVM_DLL static const Array<PrimExpr> GetShape(const Expr& expr, bool as_int = true);

/*!
* \brief Get dtype of expr.
Expand Down
2 changes: 1 addition & 1 deletion src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ const Array<PrimExpr> BroadcastShape(const Array<PrimExpr>& src_shape,
ICHECK(ArrayUtils::Broadcastable(leading_shape, out_shape))
<< "Only support elemwise ops with leading or tailing expand";
return leading_shape;
};
}

Expr RewriteElemwise(BlockBuilder builder, const Var& var, const Call& src_call,
const Map<Expr, Call>& new_calls, const String& config) {
Expand Down

0 comments on commit 4098908

Please sign in to comment.