Skip to content

Commit

Permalink
fix: Final working version of QAT in TRTorch
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <[email protected]>
  • Loading branch information
peri044 committed Jul 26, 2021
1 parent 715120f commit 521a0cb
Show file tree
Hide file tree
Showing 16 changed files with 92 additions and 117 deletions.
6 changes: 3 additions & 3 deletions core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ void AddEngineToGraph(

bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name) {
// Go through Lowering to simplify graph and extract weight parameters
auto graph_and_parameters = lowering::Lower(mod, method_name);
auto graph_and_parameters = lowering::Lower(mod, method_name, false);

auto g = graph_and_parameters.first;
LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n");
Expand All @@ -129,7 +129,7 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::stri

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
// Go through Lowering to simplify graph and extract weight parameters
auto graph_and_parameters = lowering::Lower(mod, method_name);
auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.convert_info.engine_settings.unfreeze_module);

auto convert_cfg = std::move(cfg.convert_info);
auto g = graph_and_parameters.first;
Expand Down Expand Up @@ -187,7 +187,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
// Compile only forward methods. forward method contains the entire graph.
if (method.name().compare("forward") == 0) {
auto new_g = std::make_shared<torch::jit::Graph>();
auto graph_and_parameters = lowering::Lower(mod, method.name());
auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.convert_info.engine_settings.unfreeze_module);

auto g = graph_and_parameters.first;
auto params = graph_and_parameters.second;
Expand Down
2 changes: 1 addition & 1 deletion core/conversion/conversionctx/ConversionCtx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
if (!settings.calibrator) {
LOG_WARNING(
"Int8 precision has been enabled but no calibrator provided. This assumes the network has Q/DQ nodes obtained from Quantization aware training. For more details, refer to https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#work-with-qat-networks");
} else{
} else {
cfg->setInt8Calibrator(settings.calibrator);
}
break;
Expand Down
2 changes: 2 additions & 0 deletions core/conversion/conversionctx/ConversionCtx.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ struct BuilderSettings {
bool sparse_weights = false;
std::set<nvinfer1::DataType> enabled_precisions = {nvinfer1::DataType::kFLOAT};
bool disable_tf32 = false;
// Internal flag to ensure torch.jit.Module does not get freezed in lowering.cpp. This is required for QAT models.
bool unfreeze_module = false;
bool refit = false;
bool debug = false;
bool strict_types = false;
Expand Down
46 changes: 1 addition & 45 deletions core/conversion/converters/impl/matrix_multiply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ auto mm_registrations TRTORCH_UNUSED =

auto mm_layer = ctx->net->addMatrixMultiply(
*self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);

TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
mm_layer->setName(util::node_info(n).c_str());
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
Expand Down Expand Up @@ -73,51 +74,6 @@ auto mm_registrations TRTORCH_UNUSED =

LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
return true;
}})
.pattern(
{"aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto self = args[0].ITensorOrFreeze(ctx);
auto mat1 = args[1].ITensorOrFreeze(ctx);
auto mat2 = args[2].ITensorOrFreeze(ctx);
auto beta = args[3].unwrapToScalar().to<float>();
auto betaTensor = tensor_to_const(ctx, torch::tensor({beta}));
auto alpha = args[4].unwrapToScalar().to<float>();
auto alphaTensor = tensor_to_const(ctx, torch::tensor({alpha}));

// Ensure self and other tensors have same nbDims by expanding the dimensions (from 0 axis) if
// necessary.
if (mat1->getDimensions().nbDims < mat2->getDimensions().nbDims) {
mat1 = addPadding(ctx, n, mat1, mat2->getDimensions().nbDims, false, false);
} else {
mat2 = addPadding(ctx, n, mat2, mat1->getDimensions().nbDims, false, false);
}

auto mm_layer = ctx->net->addMatrixMultiply(
*mat1, nvinfer1::MatrixOperation::kNONE, *mat2, nvinfer1::MatrixOperation::kNONE);
TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication layer in node: " << *n);
auto mm_scale_layer = add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kPROD,
mm_layer->getOutput(0),
alphaTensor,
util::node_info(n) + "_alphaScale");
TRTORCH_CHECK(mm_scale_layer, "Unable to create alpha scaling layer in node: " << *n);
auto beta_scale_layer = add_elementwise(
ctx, nvinfer1::ElementWiseOperation::kPROD, self, betaTensor, util::node_info(n) + "_betaScale");
TRTORCH_CHECK(beta_scale_layer, "Unable to create beta scaling layer in node: " << *n);
auto add_mm_layer = add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kSUM,
beta_scale_layer->getOutput(0),
mm_scale_layer->getOutput(0),
util::node_info(n));
TRTORCH_CHECK(add_mm_layer, "Unable to create addmm layer in node: " << *n);

auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));

LOG_DEBUG("[AddMM layer] Output tensor shape: " << out_tensor->getDimensions());
return true;
}});
} // namespace
} // namespace impl
Expand Down
19 changes: 13 additions & 6 deletions core/conversion/converters/impl/shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,22 @@ static auto shuffle_registrations TRTORCH_UNUSED =
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensorOrFreeze(ctx);
auto input_dims = in->getDimensions();
nvinfer1::Dims transposed_input_dims;
transposed_input_dims.nbDims = input_dims.nbDims;
for (int i = input_dims.nbDims - 1; i >= 0; i--) {
transposed_input_dims.d[i] = input_dims.d[input_dims.nbDims - 1 - i];
// For input tensors < 2D, return them as is
// For a 2D input tensor, return transpose(input, 0, 1) which is a general 2d matrix transpose.
if (input_dims.nbDims < 2) {
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], in);
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
return true;
}

auto shuffle_layer = ctx->net->addShuffle(*in);
TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
shuffle_layer->setReshapeDimensions(transposed_input_dims);
shuffle_layer->setZeroIsPlaceholder(true);
nvinfer1::Permutation firstPerm;
firstPerm.order[0] = 1;
firstPerm.order[1] = 0;

shuffle_layer->setFirstTranspose(firstPerm);
shuffle_layer->setZeroIsPlaceholder(false);
shuffle_layer->setName(util::node_info(n).c_str());

auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_layer->getOutput(0));
Expand Down
30 changes: 0 additions & 30 deletions core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,36 +427,6 @@ auto aten_registrations TRTORCH_UNUSED =
EvalOptions().validSchemas({
"aten::numel(Tensor self) -> int",
})})
// .evaluator({c10::Symbol::fromQualString("aten::t"),
// [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
// auto tensor_var = args.at(n->input(0));
// if (tensor_var.isIValue() && tensor_var.IValue()->isTensor()) {
// auto tensor = tensor_var.unwrapToTensor();
// return tensor.t();
// } else if (tensor_var.isITensor()) {
// auto input_tensor = tensor_var.ITensor();
// auto input_dims = input_tensor->getDimensions();
// LOG_DEBUG("[aten::t] INPUT TENSOR DIMS: " << input_dims);
// // nvinfer1::Dims transposed_input_dims;
// // for (int i = input_dims.nbDims - 1; i >= 0; i--) {
// // transposed_input_dims.d[i] = input_dims.d[input_dims.nbDims - 1 - i];
// // }
// // auto shuffle_layer = ctx->net->addShuffle(*input_tensor);
// // shuffle_layer->setReshapeDimensions(transposed_input_dims);
// // shuffle_layer->setZeroIsPlaceholder(true);
// // auto output_tensor = shuffle_layer->getOutput(0);
// auto tensor_holder = TensorContainer();
// tensor_holder.hold_tensor(input_tensor);
// auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
// return ival;
// } else {
// TRTORCH_THROW_ERROR("Unimplemented data type for aten::t evaluator: ITensor");
// return {};
// }
// },
// EvalOptions().validSchemas({
// "aten::t(Tensor self) -> Tensor",
// })})
.evaluator({c10::Symbol::fromQualString("aten::dim"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto tensor_var = args.at(n->input(0));
Expand Down
29 changes: 20 additions & 9 deletions core/lowering/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ void LowerBlock(torch::jit::Block* b) {
DropUnusedNodes(b);
}

void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, bool disable_cse) {
passes::UnpackHardSwish(g);
torch::jit::EliminateRedundantGuards(g);
torch::jit::RemoveListMutation(g);
Expand All @@ -42,9 +42,9 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
passes::Conv3DToConvolution(g);
passes::FuseAddMMBranches(g);
passes::RemoveBNDimCheck(g);
LOG_INFO("====PRE CSE =====" << *g);
// torch::jit::EliminateCommonSubexpression(g);
LOG_INFO("====POST CSE =====" << *g);
if (!disable_cse) {
torch::jit::EliminateCommonSubexpression(g);
}
// torch::jit::UnrollLoops(g);
passes::UnpackAddMM(g);
// passes::UnpackBatchNorm(g);
Expand All @@ -57,25 +57,36 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
}

torch::jit::Module LowerModule(const torch::jit::script::Module& mod) {
LOG_DEBUG("Input module is being frozen by torch::jit::freeze_module");
auto mod_ = torch::jit::freeze_module(mod);
return mod_;
}

std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
const torch::jit::script::Module& mod,
std::string method_name) {
auto lowered_mod = mod; // LowerModule(mod);
std::string method_name,
bool unfreeze_module = false) {
auto lowered_mod = unfreeze_module ? mod : LowerModule(mod);
auto g = lowered_mod.get_method(method_name).graph();
LOG_GRAPH(*g);

// Go through TRTorch Lowering to reformat graph to be conversion friendly
// and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
LOG_GRAPH("TRTorch Graph Lowering");
// lowering::LowerGraph(g);
// unfreeze_module is used to not perform constant folding on weights in the network.
// In quantization aware trained (QAT) models, weights are passed through quantize and
// dequantize nodes which should not be folded. So unfreeze_module is set to True for QAT models.
if (!unfreeze_module) {
LOG_GRAPH("TRTorch Graph Lowering");
lowering::LowerGraph(g, false);
}

LOG_GRAPH("LibTorch Lowering");
auto graph_and_ivalues = torch::jit::LowerGraph(*g, lowered_mod._ivalue());
lowering::LowerGraph(graph_and_ivalues.first);

if (unfreeze_module) {
LOG_GRAPH("TRTorch Graph Lowering");
lowering::LowerGraph(graph_and_ivalues.first, true);
}
// Is this necessary?
lowering::LowerBlock(g->block());

Expand Down
5 changes: 3 additions & 2 deletions core/lowering/lowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ namespace core {
namespace lowering {

void LowerBlock(torch::jit::Block* b);
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g);
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, bool disable_cse /*=false*/);
torch::jit::Module LowerModule(const torch::jit::script::Module& mod);
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
const torch::jit::script::Module& mod,
std::string method_name);
std::string method_name,
bool unfreeze_module /*=false*/);

} // namespace lowering
} // namespace core
Expand Down
12 changes: 6 additions & 6 deletions cpp/api/include/trtorch/trtorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,9 @@ struct TRTORCH_API CompileSpec {
* Emum for selecting engine capability
*/
enum class EngineCapability : int8_t {
kDEFAULT,
kSAFE_GPU,
kSAFE_DLA,
kSTANDARD,
kSAFETY,
kDLA_STANDALONE,
};

class TRTORCH_API TensorFormat {
Expand Down Expand Up @@ -686,12 +686,12 @@ struct TRTORCH_API CompileSpec {
* This is the behavior of FP32 layers by default.
*/
bool disable_tf32 = false;
/**

/**
* Enable sparsity for weights of conv and FC layers
*/
bool sparse_weights = false;

/**
* Build a refitable engine
*/
Expand Down
8 changes: 7 additions & 1 deletion cpp/api/src/compile_spec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,13 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {

if (internal.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
internal.convert_info.engine_settings.enabled_precisions.end()) {
internal.convert_info.engine_settings.calibrator = external.ptq_calibrator;
if (external.ptq_calibrator) {
internal.convert_info.engine_settings.calibrator = external.ptq_calibrator;
} else {
;
internal.convert_info.engine_settings.unfreeze_module = true;
internal.convert_info.engine_settings.calibrator = nullptr;
}
} else {
internal.convert_info.engine_settings.calibrator = nullptr;
}
Expand Down
2 changes: 1 addition & 1 deletion py/trtorch/csrc/register_tensorrt_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void RegisterTRTCompileSpec() {
.def("_set_torch_fallback", &trtorch::pyapi::CompileSpec::setTorchFallbackIntrusive)
.def("_set_ptq_calibrator", &trtorch::pyapi::CompileSpec::setPTQCalibratorViaHandle)
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);

ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, sparse_weights);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, disable_tf32);
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, refit);
Expand Down
2 changes: 1 addition & 1 deletion py/trtorch/csrc/tensorrt_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ c10::impl::GenericDict TensorRTBackend::compile(c10::IValue mod_val, c10::impl::
const auto& method_name = it->key();
auto method = mod.get_method(method_name);
auto graph = method.graph();
core::lowering::LowerGraph(graph);
core::lowering::LowerGraph(graph, false);
}

auto handles = c10::impl::GenericDict(
Expand Down
11 changes: 9 additions & 2 deletions py/trtorch/csrc/tensorrt_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,15 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
for (auto p : enabled_precisions) {
info.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
}

info.convert_info.engine_settings.calibrator = ptq_calibrator;
if (ptq_calibrator) {
info.convert_info.engine_settings.calibrator = ptq_calibrator;
} else {
if (info.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
info.convert_info.engine_settings.enabled_precisions.end()) {
std::cout << "===INTERNAL UNFREEZE MODULE TRUE===" << std::endl;
info.convert_info.engine_settings.unfreeze_module = true;
}
}
info.convert_info.engine_settings.sparse_weights = sparse_weights;
info.convert_info.engine_settings.disable_tf32 = disable_tf32;
info.convert_info.engine_settings.refit = refit;
Expand Down
25 changes: 24 additions & 1 deletion tests/core/conversion/converters/test_shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,30 @@ TEST(Converters, ATenTransposeConvertsCorrectly) {
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenTConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%out : Tensor = aten::t(%x.1)
return (%out))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in = at::randint(0, 5, {3, 4}, {at::kCUDA});
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});

std::cout << "Running JIT" << std::endl;
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

std::cout << "Running TRT" << std::endl;
in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
auto trt = trt_results[0].reshape_as(jit_results[0]);

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenTransposeNegativeConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
Expand Down Expand Up @@ -312,7 +336,6 @@ TEST(Converters, ATenPixelShuffle3DConvertsCorrectly) {
in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
// auto trt = trt_results[0].reshape_as(jit_results[0]);

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
Expand Down
8 changes: 0 additions & 8 deletions tests/modules/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,10 @@
"model": torch.hub.load('pytorch/vision:v0.9.0', 'resnet50', pretrained=True),
"path": "both"
},
"fcn_resnet101": {
"model": torch.hub.load('pytorch/vision:v0.9.0', 'fcn_resnet101', pretrained=True),
"path": "script"
},
"ssd": {
"model": torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math="fp32"),
"path": "trace"
},
"faster_rcnn": {
"model": models.detection.fasterrcnn_resnet50_fpn(pretrained=True),
"path": "script"
},
"efficientnet_b0": {
"model": timm.create_model('efficientnet_b0', pretrained=True),
"path": "script"
Expand Down
2 changes: 1 addition & 1 deletion tests/util/run_graph_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ std::vector<at::Tensor> RunGraphEngine(
auto in = toInputs(inputs);
auto info = core::conversion::ConversionInfo(in);
info.engine_settings.workspace_size = 1 << 20;
info.engine_settings.op_precision = op_precision;
info.engine_settings.enabled_precisions.insert(op_precision);
std::string eng = core::conversion::ConvertBlockToEngine(g->block(), info, named_params);
return RunEngine(eng, inputs);
}
Expand Down

0 comments on commit 521a0cb

Please sign in to comment.