[MSC] Reconstruct tensorrt module (#17344)
* reconstruct tensorrt

* format fix
Archermmt authored Sep 8, 2024
1 parent dcd32ac commit 521ab47
Showing 11 changed files with 642 additions and 283 deletions.
2 changes: 1 addition & 1 deletion python/tvm/contrib/msc/core/frontend/translate.py
@@ -330,7 +330,7 @@ def _is_target_func(func):
    msc_mod = _partition_mod(mod)
    func_names = [var.name_hint for var, func in msc_mod.functions.items() if _is_target_func(func)]

-    if not trans_config.get("allow_incomplete", False):
+    if trans_config.get("as_complete", True):
        assert len(func_names) == 1, "More than 1 target func is found: " + str(msc_mod)
        BYOCChecker().check(func_names, msc_mod[entry])
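In practice, callers that previously passed allow_incomplete=True to skip the single-function BYOC check now express the same intent with the renamed key. A minimal sketch of the new config (key name taken from the diff above, not a full translate call):

# was: trans_config = {"allow_incomplete": True}
trans_config = {"as_complete": False}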
@@ -49,7 +49,10 @@ def transform_for_tensorrt(
    return tvm.transform.Sequential(
        [
            msc_transform.SetExprName(),
-            trt_transform.TransformTensorRT(trans_config.get("version")),
+            trt_transform.TransformTensorRT(
+                version=trans_config.get("version"),
+                linear_to_conv=trans_config.get("linear_to_conv", False),
+            ),
            relax.transform.FoldConstant(),
        ]
    )(mod)
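The new option is threaded through trans_config; a minimal caller-side sketch (the full signature of transform_for_tensorrt is not shown in this diff, so the call itself is only indicated):

# Sketch only: key names follow the lookups above.
trans_config = {
    "version": [8, 6, 1],      # example TensorRT version triple
    "linear_to_conv": True,    # rewrite linear/dense as conv2d during the transform
}
# mod = transform_for_tensorrt(mod, trans_config=trans_config)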
31 changes: 21 additions & 10 deletions python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py
@@ -136,12 +136,22 @@ def _check_expr(expr: relax.Expr, dtypes: Tuple[str] = None) -> bool:
        return True
    if isinstance(expr, relax.Tuple):
        return all(_check_expr(field) for field in expr.fields)
-    if any(i < 0 for i in expr.struct_info.shape.values):
-        return False
-    dtypes = dtypes or ("float32", "float16")
-    if expr.struct_info.dtype not in dtypes:
-        return False
-    return True
+    dtypes = dtypes or ("float32", "float16", "int64", "int32", "bool")
+
+    def _check(sinfo):
+        if not sinfo.shape or sinfo.dtype not in dtypes:
+            return False
+        unknown_dim = 0
+        for s in sinfo.shape.values:
+            if isinstance(s, (tvm.tir.Var, tvm.tir.Any)):
+                unknown_dim += 1
+            elif isinstance(s, tvm.tir.IntImm) and s < 0:
+                unknown_dim += 1
+        return unknown_dim <= 1
+
+    if isinstance(expr.struct_info, relax.TupleStructInfo):
+        return all(_check(s) for s in expr.struct_info.fields)
+    return _check(expr.struct_info)


def _basic_check(context: PatternCheckContext) -> bool:
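The net effect of the rewritten check: an expression is offloadable when its dtype is in the supported set and at most one dimension is unknown (a symbolic var or a negative extent); tuples are checked field by field. A self-contained Python illustration of that rule on plain values (not the MSC code itself, which works on relax struct info):

def _is_supported_tensor(shape, dtype, dtypes=("float32", "float16", "int64", "int32", "bool")):
    """Mirror of the rule above on plain Python values: supported dtype, at most one unknown dim."""
    if shape is None or dtype not in dtypes:
        return False
    unknown_dims = sum(1 for dim in shape if not isinstance(dim, int) or dim < 0)
    return unknown_dims <= 1

assert _is_supported_tensor([1, 3, 224, 224], "float32")
assert _is_supported_tensor(["batch", 3, 224, 224], "float16")     # one dynamic dim is tolerated
assert not _is_supported_tensor(["batch", "seq", 768], "float32")  # two unknown dims are rejected
assert not _is_supported_tensor([1, 3, 224, 224], "int8")          # dtype not in the supported set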
@@ -216,8 +226,7 @@ def _reshape_check(context: PatternCheckContext) -> bool:
        Whether the pattern is correct.
    """

-    dtypes = ("float32", "float16", "int32")
-    if any(not _check_expr(context.annotated_expr[key], dtypes) for key in ["input_0", "out"]):
+    if any(not _check_expr(context.annotated_expr[key]) for key in ["input_0", "out"]):
        return False
    return True

@@ -323,16 +332,18 @@ def get_patterns(target) -> List[Pattern]:
        "nn.avg_pool2d": ["input"],
        "nn.conv2d": ["input", "constant"],
        "nn.max_pool2d": ["input"],
+        "astype": ["input"],
        "concat": ["input"],
        "clip": ["input", "input", "input"],
        "image.resize2d": ["input", "input"],
        "matmul": ["input", "input"],
        "permute_dims": ["input"],
-        "strided_slice": ["input"],
+        "strided_slice": ["input", "input", "input", "input", "input"],
        "topk": ["input"],
    }
    activation_ops = ["nn.relu", "nn.softmax", "sigmoid", "tanh"]
    reduce_ops = ["max", "min", "mean", "sum"]
-    unary_ops = ["cos", "exp", "negative", "round", "sin", "square", "sqrt", "tan"]
+    unary_ops = ["cos", "erf", "exp", "negative", "round", "sin", "square", "sqrt", "tan"]
    elemwise_ops = [
        "add",
        "divide",
13 changes: 10 additions & 3 deletions python/tvm/contrib/msc/framework/tensorrt/transform/transform.py
@@ -25,18 +25,25 @@
from tvm.contrib.msc.core import utils as msc_utils


-def TransformTensorRT(version: List[int] = None) -> tvm.ir.transform.Pass:
+def TransformTensorRT(
+    version: List[int] = None, linear_to_conv: bool = False
+) -> tvm.ir.transform.Pass:
    """Transform the Function to fit TensorRT.
    Parameters
    ----------
    version: list<int>
        The tensorrt version.
+    linear_to_conv: bool
+        Whether to cast linear to conv2d
    Returns
    -------
    ret: tvm.ir.transform.Pass
    """

-    version = version or msc_utils.get_version(MSCFramework.TENSORRT)
-    return relax_api.TransformTensorRT(version)  # type: ignore
+    config = {
+        "version": version or msc_utils.get_version(MSCFramework.TENSORRT),
+        "linear_to_conv": linear_to_conv,
+    }
+    return relax_api.TransformTensorRT(msc_utils.dump_dict(config))  # type: ignore
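A minimal usage sketch of the new signature. The import path is inferred from the file location above, the version triple is only an example, and applying the pass assumes an existing Relax IRModule:

import tvm
# Path inferred from python/tvm/contrib/msc/framework/tensorrt/transform/transform.py
from tvm.contrib.msc.framework.tensorrt.transform import transform as trt_transform

trt_pass = trt_transform.TransformTensorRT(version=[8, 6, 1], linear_to_conv=True)
# mod = tvm.transform.Sequential([trt_pass, tvm.relax.transform.FoldConstant()])(mod)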
58 changes: 58 additions & 0 deletions src/contrib/msc/core/transform/rewrite_utils.cc
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/contrib/msc/core/transform/rewrite_utils.cc
*/
#include "rewrite_utils.h"

#include <set>
#include <string>

namespace tvm {
namespace contrib {
namespace msc {

Var RewriteUtils::ReEmit(BlockBuilder builder, const String& name, const Expr& expr) {
  expr->span = SpanUtils::SetAttr(expr->span, msc_attr::kName, name);
  return builder->Emit(expr, name);
}

Var RewriteUtils::MakeCall(BlockBuilder builder, const String& name, Expr op, Array<Expr> args,
                           Attrs attrs) {
  const auto& call = Call(op, args, attrs);
  return ReEmit(builder, name, call);
}

Expr RewriteUtils::MakeConstant(BlockBuilder builder, const String& name, double value,
                                const DataType& dtype, size_t ndim) {
  const auto& data = support::FloatImmToNDArray(FloatImm(dtype, value));
  Span span = SpanUtils::CreateWithAttr(msc_attr::kName, name);
  const auto& constant = Constant(data, NullOpt, span);
  if (ndim == 0) {
    return constant;
  }
  static const Op& reshape_op = Op::Get("relax.reshape");
  Array<PrimExpr> exp_shape(ndim, Integer(1));
  return MakeCall(builder, name + "_exp", reshape_op, {constant, ShapeExpr(exp_shape)});
}

} // namespace msc
} // namespace contrib
} // namespace tvm
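MakeConstant emits a scalar constant and, when a non-zero ndim is requested, reshapes it to an all-ones shape of that rank so it lines up with a tensor operand of the same rank. A rough Python-level analogue for illustration only (the helper name is hypothetical and not part of this commit):

from tvm import relax

def make_constant_sketch(value: float, dtype: str = "float32", ndim: int = 0):
    """Scalar relax constant, expanded to shape [1] * ndim when ndim > 0."""
    const = relax.const(value, dtype)
    if ndim == 0:
        return const
    return relax.op.reshape(const, [1] * ndim)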
72 changes: 72 additions & 0 deletions src/contrib/msc/core/transform/rewrite_utils.h
@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/contrib/msc/core/transform/rewrite_utils.h
* \brief Common utilities for rewrite.
*/
#ifndef TVM_CONTRIB_MSC_CORE_TRANSFORM_REWRITE_UTILS_H_
#define TVM_CONTRIB_MSC_CORE_TRANSFORM_REWRITE_UTILS_H_

#include <tvm/ir/source_map.h>
#include <tvm/relax/expr.h>

#include <vector>

#include "../../../../relax/transform/utils.h"
#include "../../../../support/scalars.h"
#include "../utils.h"

namespace tvm {
namespace contrib {
namespace msc {

using Expr = tvm::RelayExpr;
using namespace tvm::relax;

/*!
 * \brief Utils for rewrite.
 */
class RewriteUtils {
 public:
  /*!
   * \brief Emit call with span name.
   * \return The emitted var.
   */
  TVM_DLL static Var ReEmit(BlockBuilder builder, const String& name, const Expr& expr);

  /*!
   * \brief Make and emit a call binding with span.
   * \return The emitted var.
   */
  TVM_DLL static Var MakeCall(BlockBuilder builder, const String& name, Expr op, Array<Expr> args,
                              Attrs attrs = Attrs());

  /*!
   * \brief Make and emit a (shaped) constant with span.
   * \return The constant/reshape.
   */
  TVM_DLL static Expr MakeConstant(BlockBuilder builder, const String& name, double value,
                                   const DataType& dtype, size_t ndim = 0);
};

} // namespace msc
} // namespace contrib
} // namespace tvm
#endif // TVM_CONTRIB_MSC_CORE_TRANSFORM_REWRITE_UTILS_H_
19 changes: 16 additions & 3 deletions src/contrib/msc/core/utils.cc
@@ -507,12 +507,25 @@ const String ExprUtils::GetSpanName(const Expr& expr, const String& suffix) {
  return name;
}

-const Array<PrimExpr> ExprUtils::GetShape(const Expr& expr) {
-  const auto& shape_opt = Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr))->GetShape();
-  ICHECK(shape_opt.defined()) << "Shape is not defined for " << expr;
+const Array<PrimExpr> ExprUtils::GetShape(const relax::TensorStructInfo& sinfo, bool as_int) {
+  const auto& shape_opt = sinfo->GetShape();
+  if (!shape_opt.defined()) {
+    return Array<PrimExpr>();
+  }
+  if (as_int) {
+    Array<PrimExpr> shape;
+    for (const auto& s : shape_opt.value()) {
+      shape.push_back(s->IsInstance<IntImmNode>() ? s : Integer(-1));
+    }
+    return shape;
+  }
  return shape_opt.value();
}

+const Array<PrimExpr> ExprUtils::GetShape(const Expr& expr, bool as_int) {
+  return GetShape(Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr)), as_int);
+}
+
const DataType ExprUtils::GetDataType(const Expr& expr) {
  return Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr))->dtype;
}
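With as_int=True, any dimension that is not a concrete integer is reported as -1, so callers receive a plain static shape with -1 marking dynamic dims. An illustrative Python analogue (assumes the struct info carries an explicit ShapeExpr; the helper is hypothetical):

from tvm import relax, tir

def shape_as_int(sinfo: relax.TensorStructInfo):
    """Concrete dims as Python ints, symbolic/unknown dims as -1, [] when no shape is defined."""
    if sinfo.shape is None:
        return []
    return [int(dim) if isinstance(dim, tir.IntImm) else -1 for dim in sinfo.shape.values]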
4 changes: 3 additions & 1 deletion src/contrib/msc/core/utils.h
@@ -398,7 +398,9 @@ class ExprUtils {
   * \brief Get shape of expr.
   * \return The shape.
   */
-  TVM_DLL static const Array<PrimExpr> GetShape(const Expr& expr);
+  TVM_DLL static const Array<PrimExpr> GetShape(const relax::TensorStructInfo& sinfo,
+                                                bool as_int = true);
+  TVM_DLL static const Array<PrimExpr> GetShape(const Expr& expr, bool as_int = true);

  /*!
   * \brief Get dtype of expr.
6 changes: 4 additions & 2 deletions src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc
@@ -92,6 +92,8 @@ const String TensorRTOpCode::DType(const DataType& dtype) {
    dtype_enum = "DataType::kINT8";
  } else if (dtype_name == "int32") {
    dtype_enum = "DataType::kINT32";
+  } else if (dtype_name == "int64") {
+    dtype_enum = "DataType::kINT32";
  } else if (dtype_name == "float16") {
    dtype_enum = "DataType::kHALF";
  } else if (dtype_name == "float32") {
@@ -267,7 +269,7 @@ class TensorRTAstypeCodeGen : public TensorRTOpCode {
  void CodeGenBuild() final {
    stack_.op_call()
        .op_input_arg()
-        .func_call("setOutput", NullOpt, DocUtils::ToPtr(IdxNode()))
+        .func_call("setOutputType", NullOpt, DocUtils::ToPtr(IdxNode()))
        .call_arg(0)
        .op_dtype_arg(node()->OutputAt(0)->dtype);
  }
@@ -661,7 +663,7 @@ class TensorRTTopkCodeGen : public TensorRTOpCode {

 protected:
  void CodeGenBuild() final {
-    const String& symbol = node()->GetTypeAttr<bool>("is_asend") ? "MIN" : "MAX";
+    const String& symbol = node()->GetTypeAttr<bool>("largest") ? "MAX" : "MIN";
    stack_.op_call()
        .op_input_arg()
        .call_arg("TopKOperation::k" + symbol)