Improve batch_norm fp16 accuracy #1450

Merged
core/conversion/converters/impl/batch_norm.cpp (23 changes: 18 additions & 5 deletions)

@@ -20,15 +20,28 @@ void _batch_norm(
     const torch::Tensor& mean,
     const torch::Tensor& var,
     const float eps) {
-  auto scale = gamma / torch::sqrt(var + eps);
-  auto bias = beta - mean * scale;
+  auto orig_dtype = var.dtype();
+  // perform compile-time weight calculations in float to improve accuracy
+  // resulting weights will be embedded as the original dtype
+  auto calculation_gamma = gamma;
+  auto calculation_beta = beta;
+  auto calculation_mean = mean;
+  auto calculation_var = var;
+  if (orig_dtype == torch::kHalf) {

Inline review thread on the line above:

Collaborator: Quick question: is this different from the normal PyTorch behavior? If so, can we add a debug message here saying that we are doing this to improve accuracy?

Contributor (author): It looks like the cuDNN implementation at least asserts that the weight is fp32, which would force similar calculations to fp32:
https://github.com/pytorch/pytorch/blob/4bfe2a24505049fa4fe43d24c2e3a5f5d99d9f00/aten/src/ATen/native/cudnn/BatchNorm.cpp#L110

Collaborator: Ok

+    calculation_gamma = calculation_gamma.to(torch::kFloat);
+    calculation_beta = calculation_beta.to(torch::kFloat);
+    calculation_mean = calculation_mean.to(torch::kFloat);
+    calculation_var = calculation_var.to(torch::kFloat);
+  }
+  auto scale = calculation_gamma / torch::sqrt(calculation_var + eps);
+  auto bias = calculation_beta - calculation_mean * scale;
   LOG_DEBUG("_batch_norm Tensor Scale : " << scale.sizes());
   LOG_DEBUG("_batch_norm Tensor bias : " << bias.sizes());

-  auto scale_weights = Weights(ctx, scale);
-  auto bias_weights = Weights(ctx, bias);
+  auto scale_weights = Weights(ctx, scale.to(orig_dtype));
+  auto bias_weights = Weights(ctx, bias.to(orig_dtype));

-  auto power = Weights(ctx, at::ones_like(scale));
+  auto power = Weights(ctx, at::ones_like(scale).to(orig_dtype));
   auto bn =
       ctx->net->addScaleNd(*input, nvinfer1::ScaleMode::kCHANNEL, bias_weights.data, scale_weights.data, power.data, 1);
   bn->setName(util::node_info(n).c_str());
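For intuition about the change above, here is a minimal standalone LibTorch sketch (an illustration written for this writeup, not code from the PR) contrasting the pre-patch behavior, where half-precision weights kept the folding math in half, with the patched behavior of folding in float and casting only the final weights. Small variances make the half path visibly lossy:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Sketch of the accuracy gap: scale = gamma / sqrt(var + eps) computed
  // entirely in half vs. computed in float and cast back to half at the end.
  // (Runs on CPU; move tensors to CUDA if your build lacks CPU half kernels.)
  const float eps = 1e-5f;
  auto gamma = torch::rand({64});       // float32 reference inputs
  auto var = torch::rand({64}) * 1e-3;  // small variances amplify fp16 round-off

  // Patched behavior: math in float32, cast only the embedded weight.
  auto scale_float = (gamma / torch::sqrt(var + eps)).to(torch::kHalf);

  // Pre-patch behavior when the network weights are already half.
  auto scale_half = gamma.to(torch::kHalf) / torch::sqrt(var.to(torch::kHalf) + eps);

  auto diff = (scale_float.to(torch::kFloat) - scale_half.to(torch::kFloat)).abs().max();
  std::cout << "max |float-path - half-path| = " << diff.item<float>() << std::endl;
  return 0;
}
```

This also matches the author's point in the review thread: eager PyTorch ends up doing this math in fp32 anyway, since cuDNN's batch norm asserts fp32 weights.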
tests/core/conversion/converters/test_batch_norm.cpp (30 changes: 30 additions & 0 deletions)

@@ -137,3 +137,33 @@ TEST(Converters, ATenBatchNormShouldUnpackConvertsCorrectly) {
   ASSERT_TRUE(
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }
+
+TEST(Converters, ATenBatchNormHalfConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%input : Tensor, %running_var : Half(32, strides=[1], requires_grad=0, device=cuda:0), %running_mean : Half(32, strides=[1], requires_grad=0, device=cuda:0)):
+        %5 : bool = prim::Constant[value=0]()
+        %4 : float = prim::Constant[value=0.01]()
+        %3 : float = prim::Constant[value=0.001]()
+        %2 : bool = prim::Constant[value=1]()
+        %8 : Tensor = aten::batch_norm(%input, %running_var, %running_mean, %running_mean, %running_var, %5, %4, %3, %2)
+        return (%8))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randn({2, 32, 5, 5}, {at::kCUDA}).to(at::kHalf);
+  auto mean = at::ones({32}, {at::kCUDA}).to(at::kHalf);
+  auto var = at::zeros({32}, {at::kCUDA}).to(at::kHalf);
+
+  auto trt_in = at::clone(in);
+  auto trt_mean = at::clone(mean);
+  auto trt_var = at::clone(var);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {mean, var});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_mean, trt_var});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}, {nvinfer1::DataType::kHALF});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-2));
+}
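Why 2e-2 here rather than the 2e-6 used by the fp32 test above? The PR does not state a rationale, but a back-of-envelope estimate (an assumption of this writeup, not from the PR) suggests half precision alone accounts for the looser bound:

```cpp
#include <cmath>
#include <iostream>

// Rough error budget for fp16: 10 explicit significand bits means each
// rounding step can cost up to ~2^-11 relative error, versus ~2^-24 for
// fp32 (which is why the fp32 test can demand 2e-6).
int main() {
  double fp16_rel = std::pow(2.0, -11);  // ~4.9e-4 per rounding step
  double activation = 4.0;               // randn inputs rarely exceed ~4 sigma
  std::cout << "per-op absolute error ~ " << fp16_rel * activation
            << "; a short chain of half ops stays comfortably under 2e-2"
            << std::endl;
  return 0;
}
```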