diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index c9599c7476..1d498a9930 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -21,9 +21,8 @@
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
 from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
 from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
+from torch_tensorrt.dynamo.types import TRTTensor
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
-from torch_tensorrt.fx.types import TRTTensor
-from torch_tensorrt.fx.utils import get_dynamic_dims
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -34,61 +33,102 @@ def batch_norm(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
-    running_mean: Optional[Union[torch.Tensor, np.ndarray]],
-    running_var: Optional[Union[torch.Tensor, np.ndarray]],
+    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    running_mean: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    running_var: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
     training: bool,
     momentum: float,
     eps: float,
     cudnn_enabled: bool,
     return_mean_rstd: bool,
 ) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
+
     if has_dynamic_shape(input.shape):
         assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm."
 
-    if weight is None:
-        weight = 1.0
+    # Save the original output shape for later use
+    output_shape = input.shape
 
+    if weight is None:
+        weight = get_trt_tensor(ctx, 1.0, f"{name}_weight")
     if bias is None:
-        bias = 0.0
-
+        bias = get_trt_tensor(ctx, 0.0, f"{name}_bias")
     if running_mean is None:
-        running_mean = 0.0
-
+        running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
     if running_var is None:
-        running_var = 1.0
+        running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var")
 
-    scale = to_numpy(weight) / np.sqrt(to_numpy(running_var) + eps)
-    bias = to_numpy(bias) - to_numpy(running_mean) * scale
-    power = np.ones_like(scale)
+    # eps_tensor for numerical stability
+    eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")
 
-    # For BatchNorm1d, reshape 1d to 2d
-    output_shape = input.shape
-    if len(input.shape) < 4:
-        assert (
-            len(get_dynamic_dims(input.shape)) <= 1
-        ), "BatchNorm1D with more than one dynamic dims is not currently supported."
-        new_shape = (
-            (input.shape[0], input.shape[1], 1, 1)
-            if len(input.shape) == 2
-            else (input.shape[0], input.shape[1], input.shape[2], 1)
-        )
-        input = impl.shuffle.reshape(
-            ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
-        )
-    layer = ctx.net.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power)
-    set_layer_name(layer, target, name, source_ir)
-    output = layer.get_output(0)
+    # adjusted_var = running_var + eps
+    adjusted_var = impl.elementwise.add(
+        ctx, target, source_ir, f"{name}_adjusted_var", running_var, eps_tensor
+    )
+
+    # sqrt_adjusted_var = sqrt(adjusted_var)
+    sqrt_adjusted_var = impl.unary.sqrt(
+        ctx, target, source_ir, f"{name}_sqrt", adjusted_var
+    )
+
+    # scale = weight / sqrt_adjusted_var
+    scale = impl.elementwise.div(
+        ctx, target, source_ir, f"{name}_scale", weight, sqrt_adjusted_var
+    )
+
+    # scaled_running_mean = running_mean * scale
+    scaled_running_mean = impl.elementwise.mul(
+        ctx, target, source_ir, f"{name}_scaled_running_mean", running_mean, scale
+    )
+
+    # bias_adjusted = bias - scaled_running_mean
+    bias_adjusted = impl.elementwise.sub(
+        ctx, target, source_ir, f"{name}_bias_adjusted", bias, scaled_running_mean
+    )
+
+    # Reshape scale and bias_adjusted to match input shape for broadcasting
+    expanded_shape = [1] * len(output_shape)
+    expanded_shape[1] = output_shape[1]  # Set channel dimension
+
+    scale_reshape = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape_scale",
+        scale,
+        tuple(expanded_shape),
+    )
+    bias_adjusted_reshape = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape_bias",
+        bias_adjusted,
+        tuple(expanded_shape),
+    )
+
+    # Apply the scale and bias to the input
+    scaled_input = impl.elementwise.mul(
+        ctx, target, source_ir, f"{name}_scaled_input", input, scale_reshape
+    )
+    output = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_output",
+        scaled_input,
+        bias_adjusted_reshape,
+    )
 
-    # For BatchNorm1d, reshape output back to 1d
+    # For BatchNorm1d, reshape output back to original shape if necessary
     if len(output_shape) < 4:
         output = impl.shuffle.reshape(
             ctx,
             target,
             source_ir,
             f"{name}_reshape_1d",
-            layer.get_output(0),
+            output,
             output_shape,
         )
 
diff --git a/tests/py/dynamo/conversion/test_batch_norm_aten.py b/tests/py/dynamo/conversion/test_batch_norm_aten.py
index bb1e0d8931..7b43b50769 100644
--- a/tests/py/dynamo/conversion/test_batch_norm_aten.py
+++ b/tests/py/dynamo/conversion/test_batch_norm_aten.py
@@ -8,12 +8,82 @@
 
 
 class TestBatchNormConverter(DispatchTestCase):
-    def test_batchnorm(self):
+    def test_batchnorm_static_weights(self):
         class BatchNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.batch_norm.default(
                     x,
+                    torch.full((FEATURE_NUM,), 3, dtype=torch.float32),
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.full((FEATURE_NUM,), 3, dtype=torch.float32),
+                    False,
+                    0.1,
+                    1e-05,
+                    True,
+                )
+
+        inputs = [torch.randn(1, 3, 224, 224)]
+        self.run_test(
+            BatchNorm(),
+            inputs,
+        )
+
+    def test_batchnorm_ITensor_weights_bias(self):
+        class BatchNorm(torch.nn.Module):
+            def forward(self, x, weight, bias):
+                return torch.ops.aten.batch_norm.default(
+                    x,
+                    weight,
+                    bias,
+                    torch.zeros((FEATURE_NUM,)),
                     torch.ones((FEATURE_NUM,)),
+                    False,
+                    0.1,
+                    1e-05,
+                    True,
+                )
+
+        inputs = [
+            torch.randn(1, 3, 224, 224),
+            torch.ones((FEATURE_NUM,)),
+            torch.zeros((FEATURE_NUM,)),
+        ]
+        self.run_test(
+            BatchNorm(),
+            inputs,
+        )
+
+    def test_batchnorm_ITensor_weights(self):
+        class BatchNorm(torch.nn.Module):
+            def forward(self, x, weight):
+                return torch.ops.aten.batch_norm.default(
+                    x,
+                    weight,
+                    None,
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.ones((FEATURE_NUM,)),
+                    False,
+                    0.1,
+                    1e-05,
+                    True,
+                )
+
+        inputs = [
+            torch.randn(1, 3, 224, 224),
+            torch.ones((FEATURE_NUM,)),
+        ]
+        self.run_test(
+            BatchNorm(),
+            inputs,
+        )
+
+    def test_batchnorm_static_bias_only(self):
+        class BatchNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.batch_norm.default(
+                    x,
+                    None,
                     torch.zeros((FEATURE_NUM,)),
                     torch.zeros((FEATURE_NUM,)),
                     torch.ones((FEATURE_NUM,)),
@@ -57,7 +127,7 @@ def forward(self, x):
             input_specs,
         )
 
-    def test_batchnorm_with_dynamic_shape(self):
+    def test_batchnorm2d_with_dynamic_shape(self):
         class BatchNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.batch_norm.default(
@@ -87,7 +157,7 @@ def forward(self, x):
 
 
 class TestNativeBatchNormConverter(DispatchTestCase):
-    def test_batchnorm(self):
+    def test_native_batchnorm_static_weights(self):
         class BatchNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.native_batch_norm.default(
@@ -107,7 +177,30 @@ def forward(self, x):
             inputs,
         )
 
-    def test_batchnorm_legit_no_training(self):
+    def test_native_batchnorm_legit_no_training_with_trt_tensor(self):
+        class BatchNorm(torch.nn.Module):
+            def forward(self, x, running_mean, running_var):
+                return torch.ops.aten._native_batch_norm_legit_no_training.default(
+                    x,
+                    torch.ones((FEATURE_NUM,)),
+                    torch.zeros((FEATURE_NUM,)),
+                    running_mean,
+                    running_var,
+                    0.1,
+                    1e-05,
+                )[0]
+
+        inputs = [
+            torch.randn(1, 3, 224, 224),
+            torch.zeros((FEATURE_NUM,)),
+            torch.ones((FEATURE_NUM,)),
+        ]
+        self.run_test(
+            BatchNorm(),
+            inputs,
+        )
+
+    def test_native_batchnorm_legit_no_training_with_static_means(self):
         class BatchNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten._native_batch_norm_legit_no_training.default(
@@ -126,7 +219,7 @@ def forward(self, x):
             inputs,
         )
 
-    def test_batchnorm1d_with_dynamic_shape(self):
+    def test_native_batchnorm1d_with_dynamic_shape(self):
         class BatchNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.native_batch_norm.default(
@@ -153,7 +246,7 @@ def forward(self, x):
             input_specs,
         )
 
-    def test_batchnorm_with_dynamic_shape(self):
+    def test_native_batchnorm2d_with_dynamic_shape(self):
         class BatchNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.native_batch_norm.default(
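
As a sanity check on the math behind the new converter: in inference mode, batch norm collapses to a per-channel affine transform y = x * scale + shift, where scale = weight / sqrt(running_var + eps) and shift = bias - running_mean * scale, which is exactly the chain of elementwise ops the patch builds. The short PyTorch sketch below verifies that identity against torch.nn.functional.batch_norm; it mirrors the intermediate names used in the patch (adjusted_var, scale, bias_adjusted, expanded_shape) but is a standalone illustration and does not call any torch_tensorrt API.

import torch


def batch_norm_decomposed(x, weight, bias, running_mean, running_var, eps=1e-5):
    # adjusted_var = running_var + eps; scale = weight / sqrt(adjusted_var)
    adjusted_var = running_var + eps
    scale = weight / adjusted_var.sqrt()
    # bias_adjusted = bias - running_mean * scale
    bias_adjusted = bias - running_mean * scale
    # Reshape to (1, C, 1, ..., 1) so the per-channel terms broadcast over N, H, W, ...
    expanded_shape = [1] * x.dim()
    expanded_shape[1] = x.shape[1]
    return x * scale.reshape(expanded_shape) + bias_adjusted.reshape(expanded_shape)


if __name__ == "__main__":
    x = torch.randn(1, 3, 224, 224)
    weight, bias = torch.rand(3) + 0.5, torch.randn(3)
    running_mean, running_var = torch.randn(3), torch.rand(3) + 0.5
    ref = torch.nn.functional.batch_norm(
        x, running_mean, running_var, weight, bias, training=False, eps=1e-5
    )
    out = batch_norm_decomposed(x, weight, bias, running_mean, running_var)
    assert torch.allclose(out, ref, atol=1e-5)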