fix(aten::batch_norm): A new batch norm implementation that hopefully
doesn't have the same performance cost

Signed-off-by: Naren Dasan <[email protected]>
narendasan committed May 8, 2020
1 parent 2d677cd commit 6461872
Showing 5 changed files with 99 additions and 87 deletions.
141 changes: 54 additions & 87 deletions core/conversion/converters/impl/batch_norm.cpp
@@ -1,3 +1,4 @@
#include "torch/torch.h"
#include "core/util/prelude.h"
#include "core/conversion/converters/converters.h"

@@ -8,93 +9,59 @@ namespace converters {
namespace impl {
namespace {

bool ConvertConvBatchNorm(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
  auto input = args[0].ITensor();
  auto shape = util::toVec(input->getDimensions());
  LOG_WARNING("Assuming channel dimension is 3 because input is from a conv layer, please verify");
  auto gamma = args[1].unwrapToTensor(at::full({shape[shape.size() - 3]}, 1));
  auto beta = args[2].unwrapToTensor(at::full({shape[shape.size() - 3]}, 1));
  auto mean = args[3].unwrapToTensor(at::full({shape[shape.size() - 3]}, 0));
  auto var = args[4].unwrapToTensor(at::full({shape[shape.size() - 3]}, 0));
  LOG_WARNING("Momentum argument is disregarded");
  //auto momentum = args[6].unwrapToDouble(0);
  auto eps = args[7].unwrapToDouble(0);

  auto w = at::diag(gamma / at::sqrt(var + eps));
  auto w_shape = w.sizes().vec();
  w_shape.push_back(1);
  w_shape.push_back(1);
  w = w.reshape(w_shape);
  auto b = beta - gamma * (mean / at::sqrt(var + eps));

  auto weights = Weights(ctx, w);
  auto bias = Weights(ctx, b);

  auto bn_as_conv = ctx->net->addConvolutionNd(*input, weights.num_output_maps, weights.kernel_shape, weights.data, bias.data);
  TRTORCH_CHECK(bn_as_conv, "Unable to create fused batch norm from node: " << *n);

  bn_as_conv->setName(util::node_info(n).c_str());

  auto bn_out = ctx->AssociateValueAndTensor(n->outputs()[0], bn_as_conv->getOutput(0));
  LOG_DEBUG("Output tensor shape: " << bn_out->getDimensions());
  return true;
}

bool ConvertLinearBatchNorm(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
  auto input = args[0].ITensor();
  auto shape = util::toVec(input->getDimensions());
  auto gamma = args[1].unwrapToTensor(at::full({shape}, 1));
  auto beta = args[2].unwrapToTensor(at::full({shape}, 1));
  auto mean = args[3].unwrapToTensor(at::full({shape}, 0));
  auto var = args[4].unwrapToTensor(at::full({shape}, 0));
  LOG_WARNING("Momentum argument is disregarded");
  //auto momentum = args[6].unwrapToDouble(0);
  auto eps = args[7].unwrapToDouble(0);

  auto mean_ = tensor_to_const(ctx, mean);
  auto bot_half = at::sqrt(var + eps);
  auto bot_half_ = tensor_to_const(ctx, bot_half);
  auto gamma_ = tensor_to_const(ctx, gamma);
  auto beta_ = tensor_to_const(ctx, beta);

  auto top_half = ctx->net->addElementWise(*input, *mean_, nvinfer1::ElementWiseOperation::kSUB);
  auto top_half_out = top_half->getOutput(0);
  auto x_hat = ctx->net->addElementWise(*top_half_out, *bot_half_, nvinfer1::ElementWiseOperation::kDIV);
  auto x_hat_out = x_hat->getOutput(0);
  auto bn_scaled = ctx->net->addElementWise(*gamma_, *x_hat_out, nvinfer1::ElementWiseOperation::kPROD);
  auto bn_scaled_out = bn_scaled->getOutput(0);
  auto bn_biased = ctx->net->addElementWise(*beta_, *bn_scaled_out, nvinfer1::ElementWiseOperation::kSUM);
  auto bn_biased_out = bn_biased->getOutput(0);

  bn_biased->setName(util::node_info(n).c_str());
  ctx->AssociateValueAndTensor(n->outputs()[0], bn_biased_out);

  return true;
}

volatile auto batch_norm_registrations = RegisterNodeConversionPatterns()
  .pattern({
    R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
                           Tensor? mean, Tensor? var,
                           bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor))SIG",
    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
      auto input = args[0].ITensor();
      auto shape = input->getDimensions();
      auto gamma = args[1].unwrapToTensor();

      if (/*training*/ args[5].unwrapToBool()) {
        LOG_WARNING(R"WARN(TRTorch only converts forward pass of graphs, but saw training = True, may see
unexpected behavior, consider placing module in eval mode before exporting the TorchScript module)WARN");
      }

      // If gamma is None this fails
      if (util::volume(shape) == gamma.numel()) {
        return ConvertLinearBatchNorm(ctx, n, args);
      } else {
        return ConvertConvBatchNorm(ctx, n, args);
      }
    }
  });
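Both helper paths above are removed by this commit. For reference, the conv-based path folded batch norm into a 1x1 convolution whose C x C kernel is diag(gamma / sqrt(var + eps)) and whose bias is beta - gamma * mean / sqrt(var + eps), so TensorRT had to run a dense matrix multiply at every spatial position just to apply a per-channel affine. A minimal standalone libtorch sketch of that equivalence (illustrative only, not part of the diff; the harness around the fold is assumed):

#include "torch/torch.h"
#include <iostream>

int main() {
  int64_t C = 5;
  auto x = torch::randn({1, C, 4, 4});
  auto gamma = torch::randn({C});
  auto beta = torch::randn({C});
  auto mean = torch::randn({C});
  auto var = torch::rand({C}) + 0.5; // keep variance positive
  double eps = 1e-5;

  // The fold used by the removed ConvertConvBatchNorm: a diagonal C x C 1x1 kernel
  auto w = torch::diag(gamma / torch::sqrt(var + eps)).reshape({C, C, 1, 1});
  auto b = beta - gamma * (mean / torch::sqrt(var + eps));

  auto via_conv = torch::conv2d(x, w, b);
  auto reference = torch::batch_norm(x, gamma, beta, mean, var,
                                     /*training=*/false, /*momentum=*/0.1,
                                     eps, /*cudnn_enabled=*/false);
  std::cout << (via_conv - reference).abs().max() << "\n"; // ~0 up to float error
}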
auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
  .pattern({
    R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta,
                           Tensor? mean, Tensor? var,
                           bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor))SIG",
    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
      auto input = args[0].ITensor();
      auto orig_shape = input->getDimensions();
      auto shape = util::toVec(orig_shape);
      auto options = torch::TensorOptions().dtype(torch::kFloat32);
      auto gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
      auto beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
      auto mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
      auto var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
      auto eps = args[7].unwrapToDouble(1e-5f);

      LOG_DEBUG("momentum disregarded");
      LOG_DEBUG("training disregarded");
      LOG_DEBUG("cudnn disregarded");

      auto should_unpack = util::toVec(orig_shape).size() < 4;
      if (should_unpack) {
        // expand spatial dims from 1D to 2D
        auto new_shape = util::toDimsPad(util::toVec(orig_shape), 4);
        LOG_DEBUG("Input shape is less than 4D got: " << orig_shape << ", inserting shuffle layer to reshape to 4D tensor shape: " << new_shape);
        auto in_shuffle = ctx->net->addShuffle(*input);
        in_shuffle->setReshapeDimensions(new_shape);
        in_shuffle->setName(std::string("[Reshape input to " + util::toStr(new_shape) + ']').c_str());
        input = in_shuffle->getOutput(0);
      }

      auto scale = gamma / torch::sqrt(var + eps);
      auto bias = beta - mean * scale;

      auto scale_weights = Weights(ctx, scale);
      auto bias_weights = Weights(ctx, bias);

      auto bn = ctx->net->addScaleNd(*input, nvinfer1::ScaleMode::kCHANNEL, bias_weights.data, scale_weights.data, {}, 1);
      bn->setName(util::node_info(n).c_str());
      auto out_tensor = bn->getOutput(0);

      if (should_unpack) {
        LOG_DEBUG("Inserting shuffle layer to reshape to back to original shape: " << orig_shape);
        auto out_shuffle = ctx->net->addShuffle(*out_tensor);
        out_shuffle->setReshapeDimensions(orig_shape);
        out_shuffle->setName(std::string("[Reshape output to " + util::toStr(orig_shape) + ']').c_str());
        out_tensor = out_shuffle->getOutput(0);
      }

      ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
      return true;
    }
  });
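The replacement converter computes scale = gamma / sqrt(var + eps) and bias = beta - mean * scale, so that gamma * (x - mean) / sqrt(var + eps) + beta collapses to the single per-channel affine scale * x + bias, which is exactly what the kCHANNEL Scale layer evaluates. A minimal libtorch sketch of the same fold (illustrative only; the random inputs and the explicit broadcast are assumptions, not TensorRT API):

#include "torch/torch.h"
#include <iostream>

int main() {
  int64_t C = 5;
  auto x = torch::randn({1, C, 4, 4});
  auto gamma = torch::randn({C});
  auto beta = torch::randn({C});
  auto mean = torch::randn({C});
  auto var = torch::rand({C}) + 0.5; // keep variance positive
  double eps = 1e-5;

  // Fold batch norm into one per-channel multiply-add,
  // i.e. what addScaleNd with ScaleMode::kCHANNEL applies at runtime
  auto scale = gamma / torch::sqrt(var + eps);
  auto bias = beta - mean * scale;
  auto via_scale = x * scale.reshape({1, C, 1, 1}) + bias.reshape({1, C, 1, 1});

  auto reference = torch::batch_norm(x, gamma, beta, mean, var,
                                     /*training=*/false, /*momentum=*/0.1,
                                     eps, /*cudnn_enabled=*/false);
  std::cout << (via_scale - reference).abs().max() << "\n"; // ~0 up to float error
}

One fused Scale layer replaces the dense 1x1 convolution (or the four-layer elementwise chain) of the old paths, which is where the performance win should come from.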


} // namespace
1 change: 1 addition & 0 deletions core/lowering/lowering.cpp
@@ -27,6 +27,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
  passes::FuseFlattenLinear(g);
  passes::Conv2DToConvolution(g);
  passes::UnpackAddMM(g);
  //passes::UnpackBatchNorm(g);
  passes::UnpackLogSoftmax(g);
  //passes::RemoveDimExeception(g);
  //irfusers::UnpackBatchNorm(g);
3 changes: 3 additions & 0 deletions core/lowering/passes/unpack_batch_norm.cpp
@@ -41,6 +41,9 @@ void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph) {
  torch::jit::SubgraphRewriter unpack_batch_norm;
  unpack_batch_norm.RegisterRewritePattern(batch_norm_pattern, expanded_batch_norm_pattern);
  unpack_batch_norm.runOnGraph(graph);
  LOG_DEBUG("[Lowering Batch Norm]: momentum disregarded");
  LOG_DEBUG("[Lowering Batch Norm]: training disregarded");
  LOG_DEBUG("[Lowering Batch Norm]: cudnn disregarded");
  LOG_GRAPH("Post unpack batchnorm: " << *graph);
}
} // Namespace passes
5 changes: 5 additions & 0 deletions tests/core/converters/BUILD
@@ -4,6 +4,10 @@ converter_test(
name = "test_activation"
)

converter_test(
name = "test_batch_norm"
)

converter_test(
name = "test_conv"
)
@@ -44,6 +48,7 @@ test_suite(
name = "test_converters",
tests = [
":test_activation",
":test_batch_norm",
":test_conv",
":test_element_wise",
":test_linear",
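Assuming the converter_test macro follows the repository's usual Bazel pattern, the new test should be runnable in isolation with an invocation along these lines (hypothetical):

bazel test //tests/core/converters:test_batch_norm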
36 changes: 36 additions & 0 deletions tests/core/converters/test_batch_norm.cpp
@@ -0,0 +1,36 @@
#include <string>
#include "gtest/gtest.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "tests/util/util.h"
#include "core/compiler.h"

TEST(Converters, ATenBatchNormConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%0 : Tensor,
          %1: Float(5),
          %2: Float(5),
          %3: Float(5),
          %4: Float(5)):
      %5 : bool = prim::Constant[value=0]()
      %6 : float = prim::Constant[value=1.0000000000000001e-05]()
      %7 : float = prim::Constant[value=0.10000000000000001]()
      %8 : Tensor = aten::batch_norm(%0, %1, %2, %3, %4, %5, %6, %7, %5)
      return (%8))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*g);

  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
  auto gamma = at::randint(1, 10, {5}, {at::kCUDA});
  auto beta = at::randint(1, 10, {5}, {at::kCUDA});
  auto mean = at::randint(1, 10, {5}, {at::kCUDA});
  auto var = at::randint(1, 10, {5}, {at::kCUDA});

  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {gamma, beta, mean, var});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

  params = trtorch::core::conversion::get_named_params(g->inputs(), {gamma, beta, mean, var});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}
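Note that the test allocates every tensor on at::kCUDA and runs the graph both through the TorchScript interpreter and through a TensorRT engine, so it assumes a CUDA-capable GPU with TensorRT available.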
