feat: Add functionality for QAT workflow
Signed-off-by: Dheeraj Peri <[email protected]>
peri044 committed Jul 7, 2021
1 parent 54f08f9 commit fc8eafb
Showing 5 changed files with 130 additions and 38 deletions.
8 changes: 4 additions & 4 deletions core/conversion/conversionctx/ConversionCtx.cpp
@@ -70,10 +70,10 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
}
input_type = nvinfer1::DataType::kFLOAT;
TRTORCH_CHECK(
settings.calibrator != nullptr,
"Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec struct with your calibrator");
cfg->setInt8Calibrator(settings.calibrator);
// TRTORCH_CHECK(
// settings.calibrator != nullptr,
// "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec struct with your calibrator");
// cfg->setInt8Calibrator(settings.calibrator);
break;
case nvinfer1::DataType::kFLOAT:
default:
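The check above is disabled because QAT engines take their INT8 scales from explicit Quantize/Dequantize nodes in the graph, so no PTQ calibrator is required. Rather than deleting the check outright, one option is to gate it on whether the module was quantization-aware trained; the sketch below assumes a hypothetical `qat` field on BuilderSettings that this commit does not add:

// Sketch only: gate the calibrator requirement on a hypothetical settings.qat
// flag, keeping the original PTQ error message for non-QAT INT8 compilation.
cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
if (!settings.qat) { // `qat` is an assumed field, not part of this commit
  TRTORCH_CHECK(
      settings.calibrator != nullptr,
      "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec struct with your calibrator");
  cfg->setInt8Calibrator(settings.calibrator);
}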
1 change: 1 addition & 0 deletions core/conversion/converters/BUILD
@@ -47,6 +47,7 @@ cc_library(
"impl/matrix_multiply.cpp",
"impl/normalize.cpp",
"impl/pooling.cpp",
"impl/quantization.cpp",
"impl/reduce.cpp",
"impl/replication_pad.cpp",
"impl/select.cpp",
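The newly registered impl/quantization.cpp itself is not shown in this diff. For orientation, a converter for the fake-quantize ops produced by QAT would look roughly like the sketch below, assuming TensorRT 8's IQuantizeLayer/IDequantizeLayer API and TRTorch's usual converter helpers (ITensorOrFreeze, tensor_to_const); none of these lines are taken from the commit:

// Rough sketch of a QAT converter body (not the actual impl/quantization.cpp):
// aten::fake_quantize_per_tensor_affine(input, scale, zero_point, quant_min, quant_max)
auto input = args[0].ITensorOrFreeze(ctx);
auto scale = args[1].unwrapToScalar().to<float>();
auto scale_tensor = tensor_to_const(ctx, torch::tensor({scale}));

// Q -> DQ pair: TensorRT fuses these into INT8 kernels and reads scales from them.
auto quantize = ctx->net->addQuantize(*input, *scale_tensor);
quantize->setAxis(0);
auto dequantize = ctx->net->addDequantize(*quantize->getOutput(0), *scale_tensor);
dequantize->setAxis(0);

ctx->AssociateValueAndTensor(n->outputs()[0], dequantize->getOutput(0));
return true;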
119 changes: 88 additions & 31 deletions core/conversion/converters/impl/conv_deconv.cpp
@@ -11,15 +11,95 @@ namespace impl {
namespace {

bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
auto in = args[0].ITensor(); // assumes non-static input Tensor
auto w = Weights(ctx, args[1].unwrapToTensor());
// Input to conv/deconv
auto in = args[0].ITensor();

// Conv/deconv parameters
auto stride = util::toDims(args[3].unwrapToIntList());
auto padding = util::toDims(args[4].unwrapToIntList());
auto dilation = util::toDims(args[5].unwrapToIntList());
bool transposed = args[6].unwrapToBool();
auto out_padding = util::toDims(args[7].unwrapToIntList());
int64_t groups = args[8].unwrapToInt();

// Reshape the parameters to 2D if needed
if (stride.nbDims == 1) {
stride = util::unsqueezeDims(stride, 1, 1);
LOG_DEBUG("Reshaped stride: " << stride);
}
if (dilation.nbDims == 1) {
dilation = util::unsqueezeDims(dilation, 1, 1);
LOG_DEBUG("Reshaped dilation: " << dilation);
}
if (padding.nbDims == 1) {
padding = util::unsqueezeDims(padding, 1, 0);
LOG_DEBUG("Reshaped padding: " << padding);
}
if (out_padding.nbDims == 1) {
out_padding = util::unsqueezeDims(out_padding, 1, 0);
LOG_DEBUG("Reshaped out_padding: " << out_padding);
}

// Get the bias tensor, or default to empty weights when bias is None.
Weights bias;
if (args[2].IValue()->isTensor()) {
bias = Weights(ctx, args[2].unwrapToTensor());
} else {
bias = Weights(); // i.e., nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0}
}

// Handle the case where the conv/deconv weights are an ITensor. This happens in QAT networks, where
// conv_weights -> Quantize -> Dequantize -> new_conv_weights -> conv <- input
// new_conv_weights is an ITensor because it is the output of the Dequantize layer defined in impl/quantization.cpp
if (args[1].isITensor()) {
// Get the kernel tensor
auto kernel = args[1].ITensor();
auto kernel_dims = kernel->getDimensions();

// Make a new Dims with only the spatial dimensions.
nvinfer1::Dims filter_dim;
int64_t nbSpatialDims = in->getDimensions().nbDims - 2;
TRTORCH_CHECK(nbSpatialDims == kernel_dims.nbDims - 2, "Number of input spatial dimensions should match the kernel spatial dimensions");
filter_dim.nbDims = nbSpatialDims;
filter_dim.d[0] = kernel_dims.d[2];
filter_dim.d[1] = kernel_dims.d[3];

// Initialize dummy (empty) constant weights to pass to the addConvolutionNd/addDeconvolutionNd APIs; the real kernel is attached below via setInput(1, ...).
auto kernel_weights = nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0};

nvinfer1::ILayer* layer = nullptr;
if (transposed) {
nvinfer1::IDeconvolutionLayer* deconvLayer =
ctx->net->addDeconvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
TRTORCH_CHECK(deconvLayer, "Unable to create deconv layer with non-const weights from node: " << *n);
deconvLayer->setStrideNd(stride);
deconvLayer->setDilationNd(dilation);
deconvLayer->setNbGroups(groups);
deconvLayer->setPaddingNd(padding);
// Set deconv kernel weights
deconvLayer->setInput(1, *kernel);
layer = deconvLayer;
} else {
nvinfer1::IConvolutionLayer* convLayer =
ctx->net->addConvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
TRTORCH_CHECK(convLayer, "Unable to create conv layer with non-const weights from node: " << *n);
convLayer->setStrideNd(stride);
convLayer->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
convLayer->setPaddingNd(padding);
convLayer->setPostPadding(out_padding);
convLayer->setDilationNd(dilation);
convLayer->setNbGroups(groups);

// Set conv kernel weights
convLayer->setInput(1, *kernel);
layer = convLayer;
}

ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0));
LOG_DEBUG("Output tensor shape: " << layer->getOutput(0)->getDimensions());
return true;
}

auto w = Weights(ctx, args[1].unwrapToTensor());
auto dims = in->getDimensions();
auto orig_dims = dims;
LOG_DEBUG("Input dims: " << orig_dims);
@@ -46,32 +126,9 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
w.kernel_shape.d[1] = 1;
LOG_DEBUG("Reshaped Weights: " << w);
}
if (stride.nbDims == 1) {
stride = util::unsqueezeDims(stride, 1, 1);
LOG_DEBUG("Reshaped stride: " << stride);
}
if (dilation.nbDims == 1) {
dilation = util::unsqueezeDims(dilation, 1, 1);
LOG_DEBUG("Reshaped dilation: " << dilation);
}
if (padding.nbDims == 1) {
padding = util::unsqueezeDims(padding, 1, 0);
LOG_DEBUG("Reshaped padding: " << padding);
}
if (out_padding.nbDims == 1) {
out_padding = util::unsqueezeDims(out_padding, 1, 0);
LOG_DEBUG("Reshaped out_padding: " << out_padding);
}

nvinfer1::ILayer* new_layer;
if (transposed) {
Weights bias;
if (args[2].IValue()->isTensor()) {
bias = Weights(ctx, args[2].unwrapToTensor());
} else {
bias = Weights(ctx, torch::zeros(w.shape.d[1] * groups));
}

// shape of deconvolution's weight: [in, out/groups, ...]
auto deconv = ctx->net->addDeconvolutionNd(*in, w.shape.d[1] * groups, w.kernel_shape, w.data, bias.data);
TRTORCH_CHECK(deconv, "Unable to create deconvolution layer from node: " << *n);
@@ -89,12 +146,12 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
#endif
new_layer = deconv;
} else {
Weights bias;
if (args[2].IValue()->isTensor()) {
bias = Weights(ctx, args[2].unwrapToTensor());
} else {
bias = Weights(ctx, torch::zeros(w.shape.d[0]));
}
// Weights bias;
// if (args[2].IValue()->isTensor()) {
// bias = Weights(ctx, args[2].unwrapToTensor());
// } else {
// bias = Weights(ctx, torch::zeros(w.shape.d[0]));
// }

// shape of convolution's weight: [out, in/groups, ...]
auto conv = ctx->net->addConvolutionNd(*in, w.shape.d[0], w.kernel_shape, w.data, bias.data);
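The essential TensorRT idiom in the new ITensor-weights path is to create the layer with empty placeholder weights and then attach the dequantized kernel as layer input 1, so the Q/DQ pair on the weights stays in the network. Distilled to a standalone sketch (`network`, `in`, `kernel`, `out_channels`, and `filter_dim` are assumed names, not from the commit):

// Placeholder weights: TensorRT requires a Weights argument at creation time,
// but the real kernel arrives as a tensor via setInput below.
nvinfer1::Weights empty{nvinfer1::DataType::kFLOAT, nullptr, 0};
auto conv = network->addConvolutionNd(*in, out_channels, filter_dim, empty, empty);
// Input index 1 is the kernel; here it would be the output of a Dequantize layer,
// which lets TensorRT quantize the weights with the learned QAT scales.
conv->setInput(1, *kernel);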
24 changes: 24 additions & 0 deletions core/conversion/converters/impl/linear.cpp
@@ -40,6 +40,30 @@ auto linear_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().patt
in = in_shuffle->getOutput(0);
}

// Get the bias, or default to empty weights when bias is None.
Weights bias;
if (!args[2].IValue()->isNone()) {
bias = Weights(ctx, args[2].IValue()->toTensor());
} else {
bias = Weights();
}

// Handle the case where the weights of the linear layer are an ITensor. This happens in QAT networks, where
// linear_weights -> Quantize -> Dequantize -> new_linear_weights -> linear <- input
// new_linear_weights is an ITensor because it is the output of the Dequantize layer defined in impl/quantization.cpp
if (args[1].isITensor()) {
auto kernel_tensor = args[1].ITensor();
auto kernel_dims = args[1].ITensor()->getDimensions();
// Initialize dummy (empty) constant weights to pass to the addFullyConnected API; the real kernel is attached via setInput(1, ...).
auto kernel_weights = nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0};
auto fc_layer = ctx->net->addFullyConnected(*in, kernel_dims.d[0], kernel_weights, bias.data);
TRTORCH_CHECK(fc_layer, "Unable to create fully connected layer with non-const weights from node: " << *n);
fc_layer->setInput(1, *kernel_tensor);
fc_layer->setName(util::node_info(n).c_str());
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], fc_layer->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
return true;
}

auto w_tensor = args[1].IValue()->toTensor();
Weights w = Weights(ctx, w_tensor);

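The fully connected path uses the same placeholder-plus-setInput idiom; the one extra constraint is that addFullyConnected expects an input with at least three trailing dimensions, which the converter arranges with the shuffle layer earlier in this pattern. A condensed sketch under the same assumptions as the conv example above (`network`, `in`, `kernel_tensor`, and `out_features` are assumed names):

nvinfer1::Weights empty{nvinfer1::DataType::kFLOAT, nullptr, 0};
auto fc = network->addFullyConnected(*in, /*nbOutputs=*/out_features, empty, empty);
fc->setInput(1, *kernel_tensor); // attach the dequantized weight tensor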
16 changes: 13 additions & 3 deletions core/lowering/passes/linear_to_addmm.cpp
@@ -19,8 +19,11 @@ namespace passes {
void replaceLinearWithBiasNonePattern(std::shared_ptr<torch::jit::Graph> graph) {
// Define decomposition functions for aten::linear, covering both the regular case and the case where bias (mat2) is None.
static torch::jit::CompilationUnit decompose_funcs(R"SCRIPT(
def linear(self: Tensor, mat1: Tensor, mat2: Tensor):
def linear_bias_none(self: Tensor, mat1: Tensor, mat2: Tensor) -> Tensor:
return torch.matmul(self, mat1.t())
def linear(self: Tensor, mat1: Tensor, mat2: Tensor) -> Tensor:
return torch.matmul(self, torch.transpose(mat1, 0, 1)) + mat2
)SCRIPT");

// Iterate through nodes and replace every aten::linear node with the decomposed graph, using the bias-free variant when bias is not a Tensor (includes the bias=None case)
@@ -29,16 +32,23 @@ void replaceLinearWithBiasNonePattern(std::shared_ptr<torch::jit::Graph> graph)
auto n = *it;
if (n->kind().toQualString() == std::string("aten::linear")) {
auto input_values = n->inputs();
LOG_GRAPH("Linear weight is a constant Tensor: " << input_values[1]->type()->isSubtypeOf(c10::TensorType::get()));
// input_values[2] is the bias. If it is not a Tensor (includes the None case), use the bias-free decomposition.
if (input_values[2]->type()->isSubtypeOf(c10::TensorType::get())) {
continue;
} else {
// continue;
torch::jit::WithInsertPoint guard(*it);
std::shared_ptr<torch::jit::Graph> d_graph = decompose_funcs.get_function("linear").graph();
torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);
new_output->setType(it->output()->type());
it->output()->replaceAllUsesWith(new_output);
it.destroyCurrent();
} else {
torch::jit::WithInsertPoint guard(*it);
std::shared_ptr<torch::jit::Graph> d_graph = decompose_funcs.get_function("linear_bias_none").graph();
torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);
new_output->setType(it->output()->type());
it->output()->replaceAllUsesWith(new_output);
it.destroyCurrent();
}
}
}
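With this change the pass decomposes every aten::linear, not just the bias=None case: bias-carrying nodes become matmul(input, transpose(weight, 0, 1)) + bias, and bias-free nodes become the plain matmul. A quick way to exercise it, sketched under the assumption that the pass is exposed as LinearToAddMM in core/lowering/passes (the usual TRTorch layout):

#include "torch/csrc/jit/ir/irparser.h"

// Build a trivial graph containing aten::linear and run the pass over it.
auto g = std::make_shared<torch::jit::Graph>();
std::string ir = R"IR(
  graph(%input : Tensor, %weight : Tensor, %bias : Tensor):
    %res : Tensor = aten::linear(%input, %weight, %bias)
    return (%res))IR";
torch::jit::parseIR(ir, g.get());

trtorch::core::lowering::passes::LinearToAddMM(g); // entry-point name assumed
// Expected result: %res is now matmul(%input, transpose(%weight, 0, 1)) + %bias.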
