feat: Enable TRT 8.0 QAT functionality in TRTorch
Signed-off-by: Dheeraj Peri <[email protected]>
peri044 committed Jul 14, 2021
1 parent 5708634 commit c76a28a
Showing 7 changed files with 81 additions and 20 deletions.
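For context (not part of the diff): TensorRT 8.0's explicit-quantization API represents QAT scales as Quantize/Dequantize layers, so conv/linear weights reach TRTorch's converters as ITensors rather than constant Weights. Below is a minimal sketch of that pattern, assuming an already-built INetworkDefinition; qdq_weights and its parameters are illustrative assumptions, not TRTorch APIs.

#include <NvInfer.h>

// Sketch (assumed context): weights pass through Q/DQ, so the tensor a
// converter receives for the kernel is an ITensor, not constant Weights.
nvinfer1::ITensor* qdq_weights(
    nvinfer1::INetworkDefinition* network,
    nvinfer1::Weights kernel,
    nvinfer1::Dims kernel_dims,
    nvinfer1::ITensor* scale) {
  // conv_weights -> Quantize -> Dequantize -> new_conv_weights
  auto* const_kernel = network->addConstant(kernel_dims, kernel)->getOutput(0);
  auto* q = network->addQuantize(*const_kernel, *scale);
  auto* dq = network->addDequantize(*q->getOutput(0), *scale);
  return dq->getOutput(0);
}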
3 changes: 2 additions & 1 deletion core/conversion/conversionctx/ConversionCtx.cpp
@@ -72,7 +72,8 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
       input_type = nvinfer1::DataType::kFLOAT;
       // TRTORCH_CHECK(
       //     settings.calibrator != nullptr,
-      //     "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec struct with your calibrator");
+      //     "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec
+      //     struct with your calibrator");
       // cfg->setInt8Calibrator(settings.calibrator);
       break;
     case nvinfer1::DataType::kFLOAT:
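The calibrator check above is commented out because an explicitly quantized (QAT) network carries its own scales in Q/DQ nodes, so requesting INT8 no longer implies PTQ calibration. At the TensorRT level this amounts to roughly the following sketch; enable_qat_int8 is a hypothetical helper, not TRTorch code.

#include <NvInfer.h>

// Sketch: enable INT8 for an explicitly quantized (QAT) network.
// No calibrator is attached; that path is only needed for implicit (PTQ)
// quantization.
void enable_qat_int8(nvinfer1::IBuilderConfig* cfg) {
  cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
  // cfg->setInt8Calibrator(calibrator);  // PTQ-only; intentionally omitted
}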
20 changes: 11 additions & 9 deletions core/conversion/converters/impl/conv_deconv.cpp
@@ -45,21 +45,23 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
   if (args[2].IValue()->isTensor()) {
     bias = Weights(ctx, args[2].unwrapToTensor());
   } else {
-    bias = Weights(); //nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0};
+    bias = Weights(); // nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0};
   }
 
   // Handle case when weights of conv/deconv is an ITensor. This case happens for QAT networks where
   // conv_weights -> Quantize -> Dequantize -> new_conv_weights -> conv <- input
   // new_conv_weights will be an ITensor because it is an output of Dequantize layer defined in impl/quantization.cpp
-  if (args[1].isITensor()){
+  if (args[1].isITensor()) {
     // Get the kernel tensor
     auto kernel = args[1].ITensor();
     auto kernel_dims = kernel->getDimensions();
 
     // Make a new Dims with only the spatial dimensions.
     nvinfer1::Dims filter_dim;
     int64_t nbSpatialDims = in->getDimensions().nbDims - 2;
-    TRTORCH_CHECK(nbSpatialDims = kernel_dims.nbDims - 2, "Number of input spatial dimensions should match the kernel spatial dimensions");
+    TRTORCH_CHECK(
+        nbSpatialDims == kernel_dims.nbDims - 2,
+        "Number of input spatial dimensions should match the kernel spatial dimensions");
     filter_dim.nbDims = nbSpatialDims;
     filter_dim.d[0] = kernel_dims.d[2];
     filter_dim.d[1] = kernel_dims.d[3];
@@ -68,9 +70,9 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
     auto kernel_weights = nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0};
 
     nvinfer1::ILayer* layer = nullptr;
-    if (transposed){
-      nvinfer1::IDeconvolutionLayer* deconvLayer
-          = ctx->net->addDeconvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
+    if (transposed) {
+      nvinfer1::IDeconvolutionLayer* deconvLayer =
+          ctx->net->addDeconvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
       deconvLayer->setStrideNd(stride);
       deconvLayer->setDilationNd(dilation);
       deconvLayer->setNbGroups(groups);
@@ -79,9 +81,9 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
       deconvLayer->setInput(1, *kernel);
       TRTORCH_CHECK(deconvLayer, "Unable to create deconv layer with non-const weights from node: " << *n);
       layer = deconvLayer;
-    } else{
-      nvinfer1::IConvolutionLayer* convLayer
-          = ctx->net->addConvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
+    } else {
+      nvinfer1::IConvolutionLayer* convLayer =
+          ctx->net->addConvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
       convLayer->setStrideNd(stride);
       convLayer->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
       convLayer->setPaddingNd(padding);
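The converter above relies on a TensorRT idiom worth calling out: the conv/deconv layer must be registered with (empty) constant Weights, after which input slot 1 can be rebound to the dequantized kernel ITensor. A standalone sketch under those assumptions; conv_with_itensor_weights and its arguments are illustrative, not TRTorch code.

#include <NvInfer.h>

// Sketch: convolution whose weights arrive as an ITensor (TensorRT >= 8.0).
// `network`, `input`, and `dq_kernel` are assumed to exist already.
nvinfer1::ITensor* conv_with_itensor_weights(
    nvinfer1::INetworkDefinition* network,
    nvinfer1::ITensor* input,
    nvinfer1::ITensor* dq_kernel,
    int out_channels,
    nvinfer1::Dims filter_dims) {
  // Register the layer with empty placeholder weights ...
  auto empty = nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0};
  auto* conv = network->addConvolutionNd(*input, out_channels, filter_dims, empty, empty);
  // ... then route the dequantized kernel in as the layer's second input.
  conv->setInput(1, *dq_kernel);
  return conv->getOutput(0);
}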
9 changes: 5 additions & 4 deletions core/conversion/converters/impl/linear.cpp
@@ -42,16 +42,17 @@ auto linear_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().patt
 
               // Get the bias
               Weights bias;
-              if(!args[2].IValue()->isNone()){
+              if (!args[2].IValue()->isNone()) {
                 bias = Weights(ctx, args[2].IValue()->toTensor());
-              }else {
+              } else {
                 bias = Weights();
               }
 
               // Handle case when weights of conv/deconv is an ITensor. This case happens for QAT networks where
               // conv_weights -> Quantize -> Dequantize -> new_conv_weights -> conv <- input
-              // new_conv_weights will be an ITensor because it is an output of Dequantize layer defined in impl/quantization.cpp
-              if(args[1].isITensor()){
+              // new_conv_weights will be an ITensor because it is an output of Dequantize layer defined in
+              // impl/quantization.cpp
+              if (args[1].isITensor()) {
                 auto kernel_tensor = args[1].ITensor();
                 auto kernel_dims = args[1].ITensor()->getDimensions();
                 // Initialize a dummy constant kernel to pass it to INetwork->addConvolutionNd/addDeconvolutionNd API.
55 changes: 55 additions & 0 deletions core/conversion/converters/impl/matrix_multiply.cpp
@@ -1,3 +1,4 @@
+#include <torch/torch.h>
 #include "core/conversion/converters/converter_util.h"
 #include "core/conversion/converters/converters.h"
 #include "core/util/prelude.h"
@@ -72,6 +73,60 @@ auto mm_registrations TRTORCH_UNUSED =
 
                    LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
                    return true;
-                  }});
+                  }})
+        .pattern(
+            {"aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto self = args[0].ITensorOrFreeze(ctx);
+               auto mat1 = args[1].ITensorOrFreeze(ctx);
+               auto mat2 = args[2].ITensorOrFreeze(ctx);
+               auto beta = args[4].unwrapToScalar().to<float>();
+               auto betaTensor = tensor_to_const(ctx, torch::tensor({beta}));
+               auto alpha = args[5].unwrapToScalar().to<float>();
+               auto alphaTensor = tensor_to_const(ctx, torch::tensor({alpha}));
+
+               // Ensure mat1 and mat2 have the same nbDims by expanding the dimensions (from 0 axis) if
+               // necessary.
+               if (mat1->getDimensions().nbDims < mat2->getDimensions().nbDims) {
+                 mat1 = addPadding(ctx, n, mat1, mat2->getDimensions().nbDims, false, false);
+               } else {
+                 mat2 = addPadding(ctx, n, mat2, mat1->getDimensions().nbDims, false, false);
+               }
+
+               auto mat2_dims = mat2->getDimensions();
+               nvinfer1::Dims transposed_mat2_dims = mat2_dims; // copies nbDims; every d[i] is rewritten below
+               for (int i = mat2_dims.nbDims - 1; i >= 0; i--) {
+                 transposed_mat2_dims.d[i] = mat2_dims.d[mat2_dims.nbDims - 1 - i];
+               }
+               auto shuffle_layer = ctx->net->addShuffle(*mat2);
+               shuffle_layer->setReshapeDimensions(transposed_mat2_dims);
+               mat2 = shuffle_layer->getOutput(0);
+
+               auto mm_layer = ctx->net->addMatrixMultiply(
+                   *mat1, nvinfer1::MatrixOperation::kNONE, *mat2, nvinfer1::MatrixOperation::kNONE);
+               TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication layer in node: " << *n);
+               auto mm_scale_layer = add_elementwise(
+                   ctx,
+                   nvinfer1::ElementWiseOperation::kPROD,
+                   mm_layer->getOutput(0),
+                   alphaTensor,
+                   util::node_info(n) + "_alphaScale");
+               TRTORCH_CHECK(mm_scale_layer, "Unable to create alpha scaling layer in node: " << *n);
+               auto beta_scale_layer = add_elementwise(
+                   ctx, nvinfer1::ElementWiseOperation::kPROD, self, betaTensor, util::node_info(n) + "_betaScale");
+               TRTORCH_CHECK(beta_scale_layer, "Unable to create beta scaling layer in node: " << *n);
+               auto add_mm_layer = add_elementwise(
+                   ctx,
+                   nvinfer1::ElementWiseOperation::kSUM,
+                   beta_scale_layer->getOutput(0),
+                   mm_scale_layer->getOutput(0),
+                   util::node_info(n));
+               TRTORCH_CHECK(add_mm_layer, "Unable to create addmm layer in node: " << *n);
+
+               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], add_mm_layer->getOutput(0));
+
+               LOG_DEBUG("[AddMM layer] Output tensor shape: " << out_tensor->getDimensions());
+               return true;
+             }});
 } // namespace
 } // namespace impl
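For reference, aten::addmm computes out = beta * self + alpha * (mat1 @ mat2); the converter above assembles it from a matrix-multiply layer plus two elementwise scales and a sum. A quick LibTorch check of that identity (a hedged sketch, separate from the converter code itself):

#include <iostream>
#include <torch/torch.h>

int main() {
  auto self = torch::randn({2, 4});
  auto mat1 = torch::randn({2, 3});
  auto mat2 = torch::randn({3, 4});
  float beta = 0.5f, alpha = 2.0f;
  // addmm(self, mat1, mat2, beta, alpha) == beta * self + alpha * (mat1 @ mat2)
  auto expected = beta * self + alpha * torch::mm(mat1, mat2);
  auto out = torch::addmm(self, mat1, mat2, beta, alpha);
  std::cout << std::boolalpha << torch::allclose(out, expected) << "\n"; // true
  return 0;
}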
7 changes: 6 additions & 1 deletion core/conversion/evaluators/aten.cpp
@@ -430,9 +430,14 @@ auto aten_registrations TRTORCH_UNUSED =
         .evaluator({c10::Symbol::fromQualString("aten::t"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       auto tensor_var = args.at(n->input(0));
-                      if (tensor_var.IValue()->isTensor()) {
+                      if (tensor_var.isIValue() && tensor_var.IValue()->isTensor()) {
                         auto tensor = tensor_var.unwrapToTensor();
                         return tensor.t();
+                      } else if (tensor_var.isITensor()) {
+                        auto tensor_holder = TensorContainer();
+                        tensor_holder.hold_tensor(tensor_var.ITensor());
+                        auto ival = c10::IValue(c10::make_intrusive<TensorContainer>(tensor_holder));
+                        return ival;
                       } else {
-                        TRTORCH_THROW_ERROR("Unimplemented data type for aten::t evaluator: ITensor");
+                        TRTORCH_THROW_ERROR("Unimplemented data type for aten::t evaluator");
                         return {};
4 changes: 2 additions & 2 deletions core/lowering/lowering.cpp
@@ -63,9 +63,9 @@ torch::jit::Module LowerModule(const torch::jit::script::Module& mod) {
 std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
     const torch::jit::script::Module& mod,
     std::string method_name) {
-  auto lowered_mod = LowerModule(mod);
+  auto lowered_mod = mod; // LowerModule(mod);
   auto g = lowered_mod.get_method(method_name).graph();
-  LOG_GRAPH(*g);
+  LOG_INFO(*g);
 
   // Go through TRTorch Lowering to reformat graph to be conversion friendly
   // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
3 changes: 0 additions & 3 deletions core/plugins/impl/interpolate_plugin.cpp
@@ -105,7 +105,6 @@ std::vector<int64_t> InterpolatePlugin::getOutputSize() {
   return size_;
 }
 
-
 int InterpolatePlugin::getNbOutputs() const noexcept {
   if (mode_ == "adaptive_max_pool2d") {
     return 2;
@@ -170,7 +169,6 @@ nvinfer1::DataType InterpolatePlugin::getOutputDataType(int index, const nvinfer
   return nvinfer1::DataType::kFLOAT;
 }
 
-
 int InterpolatePlugin::initialize() noexcept {
   return 0;
 }
@@ -208,7 +206,6 @@ bool InterpolatePlugin::supportsFormatCombination(
     const nvinfer1::PluginTensorDesc* inOut,
     int nbInputs,
     int nbOutputs) noexcept {
-
   TRTORCH_ASSERT(nbInputs == 1, "Expected a single tensor as input to interpolate plugin");
 
   if (mode_ == "adaptive_max_pool2d") {
