feat: Add handling for ITensor mean and var in batch_norm #3099

Merged · 3 commits · Aug 22, 2024
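This change rewrites the `batch_norm` converter so that `weight`, `bias`, `running_mean`, and `running_var` may arrive as TensorRT `ITensor`s rather than only as frozen `torch.Tensor`/`np.ndarray` constants: instead of folding the statistics into numpy arrays and emitting a single scale layer via `ctx.net.add_scale`, the converter now builds the normalization from TensorRT elementwise and unary layers, which accept graph tensors. The diff below keeps the same algebra; the following is a minimal PyTorch reference sketch of that algebra (not part of the PR, eval-mode only):

```python
import torch

def batch_norm_reference(x, weight, bias, running_mean, running_var, eps=1e-5):
    # scale = weight / sqrt(running_var + eps)
    scale = weight / torch.sqrt(running_var + eps)
    # bias_adjusted = bias - running_mean * scale
    bias_adjusted = bias - running_mean * scale
    # Reshape the per-channel terms to (1, C, 1, ..., 1) so they broadcast over x,
    # mirroring the expanded_shape reshapes in the converter.
    shape = [1] * x.dim()
    shape[1] = x.shape[1]
    return x * scale.reshape(shape) + bias_adjusted.reshape(shape)

# Sanity check against eval-mode batch_norm.
x = torch.randn(2, 3, 8, 8)
weight, bias = torch.randn(3), torch.randn(3)
mean, var = torch.randn(3), torch.rand(3) + 0.5
expected = torch.nn.functional.batch_norm(
    x, mean, var, weight, bias, training=False, eps=1e-5
)
assert torch.allclose(
    batch_norm_reference(x, weight, bias, mean, var), expected, atol=1e-5
)
```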
110 changes: 75 additions & 35 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -21,9 +21,8 @@
from torch_tensorrt.dynamo.conversion.impl.cat import cat
from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
from torch_tensorrt.dynamo.types import TRTTensor
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
from torch_tensorrt.fx.types import TRTTensor
from torch_tensorrt.fx.utils import get_dynamic_dims

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -34,61 +33,102 @@ def batch_norm(
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
weight: Optional[Union[torch.Tensor, np.ndarray]],
bias: Optional[Union[torch.Tensor, np.ndarray]],
running_mean: Optional[Union[torch.Tensor, np.ndarray]],
running_var: Optional[Union[torch.Tensor, np.ndarray]],
weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
running_mean: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
running_var: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
training: bool,
momentum: float,
eps: float,
cudnn_enabled: bool,
return_mean_rstd: bool,
) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:

if has_dynamic_shape(input.shape):
assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm."

if weight is None:
weight = 1.0
# Save the original output shape for later use
output_shape = input.shape

if weight is None:
weight = get_trt_tensor(ctx, 1.0, f"{name}_weight")
if bias is None:
bias = 0.0

bias = get_trt_tensor(ctx, 0.0, f"{name}_bias")
if running_mean is None:
running_mean = 0.0

running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
if running_var is None:
running_var = 1.0
running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var")

scale = to_numpy(weight) / np.sqrt(to_numpy(running_var) + eps)
bias = to_numpy(bias) - to_numpy(running_mean) * scale
power = np.ones_like(scale)
# eps_tensor for numerical stability
eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")

# For BatchNorm1d, reshape 1d to 2d
output_shape = input.shape
if len(input.shape) < 4:
assert (
len(get_dynamic_dims(input.shape)) <= 1
), "BatchNorm1D with more than one dynamic dims is not currently supported."
new_shape = (
(input.shape[0], input.shape[1], 1, 1)
if len(input.shape) == 2
else (input.shape[0], input.shape[1], input.shape[2], 1)
)
input = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
)
layer = ctx.net.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power)
set_layer_name(layer, target, name, source_ir)
output = layer.get_output(0)
# adjusted_var = running_var + eps
adjusted_var = impl.elementwise.add(
ctx, target, source_ir, f"{name}_adjusted_var", running_var, eps_tensor
)

# sqrt_adjusted_var = sqrt(adjusted_var)
sqrt_adjusted_var = impl.unary.sqrt(
ctx, target, source_ir, f"{name}_sqrt", adjusted_var
)

# scale = weight / sqrt_adjusted_var
scale = impl.elementwise.div(
ctx, target, source_ir, f"{name}_scale", weight, sqrt_adjusted_var
)

# scaled_running_mean = running_mean * scale
scaled_running_mean = impl.elementwise.mul(
ctx, target, source_ir, f"{name}_scaled_running_mean", running_mean, scale
)

# bias_adjusted = bias - scaled_running_mean
bias_adjusted = impl.elementwise.sub(
ctx, target, source_ir, f"{name}_bias_adjusted", bias, scaled_running_mean
)

# Reshape scale and bias_adjusted to match input shape for broadcasting
expanded_shape = [1] * len(output_shape)
expanded_shape[1] = output_shape[1] # Set channel dimension

scale_reshape = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_scale",
scale,
tuple(expanded_shape),
)
bias_adjusted_reshape = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_bias",
bias_adjusted,
tuple(expanded_shape),
)

# Apply the scale and bias to the input
scaled_input = impl.elementwise.mul(
ctx, target, source_ir, f"{name}_scaled_input", input, scale_reshape
)
output = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_output",
scaled_input,
bias_adjusted_reshape,
)

# For BatchNorm1d, reshape output back to 1d
# For BatchNorm1d, reshape output back to original shape if necessary
if len(output_shape) < 4:
output = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_1d",
layer.get_output(0),
output,
output_shape,
)

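The updated tests exercise the case this PR targets: when the statistics (or weight/bias) are themselves outputs of the graph, the converter sees TensorRT `ITensor`s instead of constants at conversion time. In the tests below this is arranged by passing them as module inputs; a hypothetical module that computes the mean and variance in-graph (an illustration, not code from the PR) would hit the same path:

```python
import torch

class NormalizeWithComputedStats(torch.nn.Module):
    def forward(self, x, weight, bias):
        # mean/var are produced inside the graph, so during conversion they are
        # ITensors rather than frozen numpy constants.
        mean = torch.mean(x, dim=(0, 2, 3))
        var = torch.var(x, dim=(0, 2, 3), unbiased=False)
        return torch.ops.aten.batch_norm.default(
            x, weight, bias, mean, var, False, 0.1, 1e-5, True
        )
```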
105 changes: 99 additions & 6 deletions tests/py/dynamo/conversion/test_batch_norm_aten.py
@@ -8,12 +8,82 @@


class TestBatchNormConverter(DispatchTestCase):
def test_batchnorm(self):
def test_batchnorm_static_weights(self):
class BatchNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.batch_norm.default(
x,
torch.full((FEATURE_NUM,), 3, dtype=torch.float32),
torch.zeros((FEATURE_NUM,)),
torch.zeros((FEATURE_NUM,)),
torch.full((FEATURE_NUM,), 3, dtype=torch.float32),
False,
0.1,
1e-05,
True,
)

inputs = [torch.randn(1, 3, 224, 224)]
self.run_test(
BatchNorm(),
inputs,
)

def test_batchnorm_ITensor_weights_bias(self):
class BatchNorm(torch.nn.Module):
def forward(self, x, weight, bias):
return torch.ops.aten.batch_norm.default(
x,
weight,
bias,
torch.zeros((FEATURE_NUM,)),
torch.ones((FEATURE_NUM,)),
False,
0.1,
1e-05,
True,
)

inputs = [
torch.randn(1, 3, 224, 224),
torch.ones((FEATURE_NUM,)),
torch.zeros((FEATURE_NUM,)),
]
self.run_test(
BatchNorm(),
inputs,
)

def test_batchnorm_ITensor_weights(self):
class BatchNorm(torch.nn.Module):
def forward(self, x, weight):
return torch.ops.aten.batch_norm.default(
x,
weight,
None,
torch.zeros((FEATURE_NUM,)),
torch.ones((FEATURE_NUM,)),
False,
0.1,
1e-05,
True,
)

inputs = [
torch.randn(1, 3, 224, 224),
torch.ones((FEATURE_NUM,)),
]
self.run_test(
BatchNorm(),
inputs,
)

def test_batchnorm_static_bias_only(self):
class BatchNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.batch_norm.default(
x,
None,
torch.zeros((FEATURE_NUM,)),
torch.zeros((FEATURE_NUM,)),
torch.ones((FEATURE_NUM,)),
@@ -57,7 +127,7 @@ def forward(self, x):
input_specs,
)

def test_batchnorm_with_dynamic_shape(self):
def test_batchnorm2d_with_dynamic_shape(self):
class BatchNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.batch_norm.default(
Expand Down Expand Up @@ -87,7 +157,7 @@ def forward(self, x):


class TestNativeBatchNormConverter(DispatchTestCase):
def test_batchnorm(self):
def test_native_batchnorm_static_weights(self):
class BatchNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.native_batch_norm.default(
@@ -107,7 +177,30 @@ def forward(self, x):
inputs,
)

def test_batchnorm_legit_no_training(self):
def test_native_batchnorm_legit_no_training_with_trt_tensor(self):
class BatchNorm(torch.nn.Module):
def forward(self, x, running_mean, running_var):
return torch.ops.aten._native_batch_norm_legit_no_training.default(
x,
torch.ones((FEATURE_NUM,)),
torch.zeros((FEATURE_NUM,)),
running_mean,
running_var,
0.1,
1e-05,
)[0]

inputs = [
torch.randn(1, 3, 224, 224),
torch.zeros((FEATURE_NUM,)),
torch.ones((FEATURE_NUM,)),
]
self.run_test(
BatchNorm(),
inputs,
)

def test_native_batchnorm_legit_no_training_with_static_means(self):
class BatchNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten._native_batch_norm_legit_no_training.default(
@@ -126,7 +219,7 @@ def forward(self, x):
inputs,
)

def test_batchnorm1d_with_dynamic_shape(self):
def test_native_batchnorm1d_with_dynamic_shape(self):
class BatchNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.native_batch_norm.default(
@@ -153,7 +246,7 @@ def forward(self, x):
input_specs,
)

def test_batchnorm_with_dynamic_shape(self):
def test_native_batchnorm2d_with_dynamic_shape(self):
class BatchNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.native_batch_norm.default(