feat: Add handling for ITensor mean and var in batch_norm (#3099)
chohk88 authored Aug 22, 2024
1 parent 9a08cc7 commit 66511da
Showing 2 changed files with 174 additions and 41 deletions.
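
For context: the rewritten converter no longer precomputes the scale and bias on the host with NumPy (which only works when weight, bias, running_mean, and running_var are static tensors). Instead it builds the normalization out of TensorRT elementwise and unary layers, so those arguments may also be runtime ITensors. A minimal NumPy sketch of the algebra the new graph implements (illustrative names, not the converter code itself):

    import numpy as np

    def batch_norm_reference(x, weight, bias, running_mean, running_var, eps=1e-5):
        # scale = weight / sqrt(running_var + eps)
        scale = weight / np.sqrt(running_var + eps)
        # bias_adjusted = bias - running_mean * scale
        bias_adjusted = bias - running_mean * scale
        # Broadcast per channel: reshape (C,) to (1, C, 1, ..., 1) with C at dim 1
        expanded_shape = (1, x.shape[1]) + (1,) * (x.ndim - 2)
        return x * scale.reshape(expanded_shape) + bias_adjusted.reshape(expanded_shape)

    # Example: equivalent to (x - mean) / sqrt(var + eps) * weight + bias
    x = np.random.randn(2, 3, 8, 8).astype(np.float32)
    w = np.full(3, 2.0, dtype=np.float32)
    b = np.full(3, 0.5, dtype=np.float32)
    m = np.zeros(3, dtype=np.float32)
    v = np.ones(3, dtype=np.float32)
    out = batch_norm_reference(x, w, b, m, v)
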
110 changes: 75 additions & 35 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -21,9 +21,8 @@
from torch_tensorrt.dynamo.conversion.impl.cat import cat
from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
from torch_tensorrt.dynamo.types import TRTTensor
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
from torch_tensorrt.fx.types import TRTTensor
from torch_tensorrt.fx.utils import get_dynamic_dims

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -34,61 +33,102 @@ def batch_norm(
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
weight: Optional[Union[torch.Tensor, np.ndarray]],
bias: Optional[Union[torch.Tensor, np.ndarray]],
running_mean: Optional[Union[torch.Tensor, np.ndarray]],
running_var: Optional[Union[torch.Tensor, np.ndarray]],
weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
running_mean: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
running_var: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
training: bool,
momentum: float,
eps: float,
cudnn_enabled: bool,
return_mean_rstd: bool,
) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:

if has_dynamic_shape(input.shape):
assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm."

if weight is None:
weight = 1.0
# Save the original output shape for later use
output_shape = input.shape

if weight is None:
weight = get_trt_tensor(ctx, 1.0, f"{name}_weight")
if bias is None:
bias = 0.0

bias = get_trt_tensor(ctx, 0.0, f"{name}_bias")
if running_mean is None:
running_mean = 0.0

running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
if running_var is None:
running_var = 1.0
running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var")

scale = to_numpy(weight) / np.sqrt(to_numpy(running_var) + eps)
bias = to_numpy(bias) - to_numpy(running_mean) * scale
power = np.ones_like(scale)
# eps_tensor for numerical stability
eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")

# For BatchNorm1d, reshape 1d to 2d
output_shape = input.shape
if len(input.shape) < 4:
assert (
len(get_dynamic_dims(input.shape)) <= 1
), "BatchNorm1D with more than one dynamic dims is not currently supported."
new_shape = (
(input.shape[0], input.shape[1], 1, 1)
if len(input.shape) == 2
else (input.shape[0], input.shape[1], input.shape[2], 1)
)
input = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
)
layer = ctx.net.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power)
set_layer_name(layer, target, name, source_ir)
output = layer.get_output(0)
# adjusted_var = running_var + eps
adjusted_var = impl.elementwise.add(
ctx, target, source_ir, f"{name}_adjusted_var", running_var, eps_tensor
)

# sqrt_adjusted_var = sqrt(adjusted_var)
sqrt_adjusted_var = impl.unary.sqrt(
ctx, target, source_ir, f"{name}_sqrt", adjusted_var
)

# scale = weight / sqrt_adjusted_var
scale = impl.elementwise.div(
ctx, target, source_ir, f"{name}_scale", weight, sqrt_adjusted_var
)

# scaled_running_mean = running_mean * scale
scaled_running_mean = impl.elementwise.mul(
ctx, target, source_ir, f"{name}_scaled_running_mean", running_mean, scale
)

# bias_adjusted = bias - scaled_running_mean
bias_adjusted = impl.elementwise.sub(
ctx, target, source_ir, f"{name}_bias_adjusted", bias, scaled_running_mean
)

# Reshape scale and bias_adjusted to match input shape for broadcasting
expanded_shape = [1] * len(output_shape)
expanded_shape[1] = output_shape[1] # Set channel dimension

scale_reshape = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_scale",
scale,
tuple(expanded_shape),
)
bias_adjusted_reshape = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_bias",
bias_adjusted,
tuple(expanded_shape),
)

# Apply the scale and bias to the input
scaled_input = impl.elementwise.mul(
ctx, target, source_ir, f"{name}_scaled_input", input, scale_reshape
)
output = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_output",
scaled_input,
bias_adjusted_reshape,
)

# For BatchNorm1d, reshape output back to 1d
# For BatchNorm1d, reshape output back to original shape if necessary
if len(output_shape) < 4:
output = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_1d",
layer.get_output(0),
output,
output_shape,
)

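Because scale and bias_adjusted are reshaped to a [1, C, 1, ...] shape and applied with elementwise mul/add, the same broadcast covers BatchNorm1d-style rank-2/rank-3 inputs as well as NCHW. A quick eager-mode sanity check of that algebra against torch.nn.functional.batch_norm; illustrative only, not part of the PR:

    import torch

    # (N, C, L) input; the channel dimension is dim 1, as in the converter's expanded_shape
    x = torch.randn(4, 3, 10)
    weight = torch.rand(3) + 0.5
    bias = torch.randn(3)
    mean = torch.randn(3)
    var = torch.rand(3) + 0.5
    eps = 1e-5

    scale = weight / torch.sqrt(var + eps)
    bias_adjusted = bias - mean * scale

    # expanded_shape = [1] * rank, with the channel size at dim 1 -> [1, 3, 1]
    expanded_shape = [1] * x.dim()
    expanded_shape[1] = x.shape[1]
    out = x * scale.reshape(expanded_shape) + bias_adjusted.reshape(expanded_shape)

    ref = torch.nn.functional.batch_norm(x, mean, var, weight, bias, False, 0.1, eps)
    assert torch.allclose(ref, out, atol=1e-5)
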
105 changes: 99 additions & 6 deletions tests/py/dynamo/conversion/test_batch_norm_aten.py
@@ -8,12 +8,82 @@


class TestBatchNormConverter(DispatchTestCase):
def test_batchnorm(self):
def test_batchnorm_static_weights(self):
class BatchNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.batch_norm.default(
x,
torch.full((FEATURE_NUM,), 3, dtype=torch.float32),
torch.zeros((FEATURE_NUM,)),
torch.zeros((FEATURE_NUM,)),
torch.full((FEATURE_NUM,), 3, dtype=torch.float32),
False,
0.1,
1e-05,
True,
)

inputs = [torch.randn(1, 3, 224, 224)]
self.run_test(
BatchNorm(),
inputs,
)

def test_batchnorm_ITensor_weights_bias(self):
class BatchNorm(torch.nn.Module):
def forward(self, x, weight, bias):
return torch.ops.aten.batch_norm.default(
x,
weight,
bias,
torch.zeros((FEATURE_NUM,)),
torch.ones((FEATURE_NUM,)),
False,
0.1,
1e-05,
True,
)

inputs = [
torch.randn(1, 3, 224, 224),
torch.ones((FEATURE_NUM,)),
torch.zeros((FEATURE_NUM,)),
]
self.run_test(
BatchNorm(),
inputs,
)

def test_batchnorm_ITensor_weights(self):
class BatchNorm(torch.nn.Module):
def forward(self, x, weight):
return torch.ops.aten.batch_norm.default(
x,
weight,
None,
torch.zeros((FEATURE_NUM,)),
torch.ones((FEATURE_NUM,)),
False,
0.1,
1e-05,
True,
)

inputs = [
torch.randn(1, 3, 224, 224),
torch.ones((FEATURE_NUM,)),
]
self.run_test(
BatchNorm(),
inputs,
)

def test_batchnorm_static_bias_only(self):
class BatchNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.batch_norm.default(
x,
None,
torch.zeros((FEATURE_NUM,)),
torch.zeros((FEATURE_NUM,)),
torch.ones((FEATURE_NUM,)),
@@ -57,7 +127,7 @@ def forward(self, x):
input_specs,
)

def test_batchnorm_with_dynamic_shape(self):
def test_batchnorm2d_with_dynamic_shape(self):
class BatchNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.batch_norm.default(
Expand Down Expand Up @@ -87,7 +157,7 @@ def forward(self, x):


class TestNativeBatchNormConverter(DispatchTestCase):
def test_batchnorm(self):
def test_native_batchnorm_static_weights(self):
class BatchNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.native_batch_norm.default(
@@ -107,7 +177,30 @@ def forward(self, x):
inputs,
)

def test_batchnorm_legit_no_training(self):
def test_native_batchnorm_legit_no_training_with_trt_tensor(self):
class BatchNorm(torch.nn.Module):
def forward(self, x, running_mean, running_var):
return torch.ops.aten._native_batch_norm_legit_no_training.default(
x,
torch.ones((FEATURE_NUM,)),
torch.zeros((FEATURE_NUM,)),
running_mean,
running_var,
0.1,
1e-05,
)[0]

inputs = [
torch.randn(1, 3, 224, 224),
torch.zeros((FEATURE_NUM,)),
torch.ones((FEATURE_NUM,)),
]
self.run_test(
BatchNorm(),
inputs,
)

def test_native_batchnorm_legit_no_training_with_static_means(self):
class BatchNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten._native_batch_norm_legit_no_training.default(
@@ -126,7 +219,7 @@ def forward(self, x):
inputs,
)

def test_batchnorm1d_with_dynamic_shape(self):
def test_native_batchnorm1d_with_dynamic_shape(self):
class BatchNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.native_batch_norm.default(
@@ -153,7 +246,7 @@ def forward(self, x):
input_specs,
)

def test_batchnorm_with_dynamic_shape(self):
def test_native_batchnorm2d_with_dynamic_shape(self):
class BatchNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.native_batch_norm.default(

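For completeness, a hedged end-to-end sketch of how the new ITensor path can be exercised outside the unit-test harness, mirroring test_native_batchnorm_legit_no_training_with_trt_tensor above. This assumes a CUDA-capable environment and a torch_tensorrt build that includes this change; the module and tensor values are illustrative:

    import torch
    import torch_tensorrt

    FEATURE_NUM = 3

    class BatchNormWithRuntimeStats(torch.nn.Module):
        def forward(self, x, running_mean, running_var):
            # Weight/bias stay constant; mean/var arrive as graph inputs,
            # so they reach the converter as ITensors.
            weight = torch.ones(FEATURE_NUM, device=x.device)
            bias = torch.zeros(FEATURE_NUM, device=x.device)
            return torch.ops.aten._native_batch_norm_legit_no_training.default(
                x, weight, bias, running_mean, running_var, 0.1, 1e-05
            )[0]

    inputs = [
        torch.randn(1, FEATURE_NUM, 224, 224).cuda(),
        torch.zeros(FEATURE_NUM).cuda(),
        torch.ones(FEATURE_NUM).cuda(),
    ]
    trt_mod = torch_tensorrt.compile(
        BatchNormWithRuntimeStats().eval().cuda(), ir="dynamo", inputs=inputs
    )
    print(trt_mod(*inputs).shape)  # torch.Size([1, 3, 224, 224])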