Skip to content

Commit

Permalink
chore: Ensure input arguments are based on ITensor (TRTTensor)
Browse files Browse the repository at this point in the history
  • Loading branch information
chohk88 committed Aug 22, 2024
1 parent 3c06a9f commit 5d6cf81
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 101 deletions.
165 changes: 64 additions & 101 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
from torch_tensorrt.dynamo.conversion.impl.cat import cat
from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
from torch_tensorrt.dynamo.types import TRTTensor
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
from torch_tensorrt.fx.types import TRTTensor
from torch_tensorrt.fx.utils import get_dynamic_dims

_LOGGER: logging.Logger = logging.getLogger(__name__)

Expand All @@ -34,8 +33,8 @@ def batch_norm(
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
weight: Optional[Union[torch.Tensor, np.ndarray]],
bias: Optional[Union[torch.Tensor, np.ndarray]],
weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
running_mean: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
running_var: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
training: bool,
Expand All @@ -51,112 +50,76 @@ def batch_norm(
# Save the original output shape for later use
output_shape = input.shape

# Handle case when running_mean or running_var is TRTTensor
if isinstance(running_mean, TRTTensor) or isinstance(running_var, TRTTensor):
# Default values if weight, bias, running_mean, running_var are None
if weight is None:
weight = get_trt_tensor(ctx, 1.0, f"{name}_weight")
if bias is None:
bias = get_trt_tensor(ctx, 0.0, f"{name}_bias")
if running_mean is None:
running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
if running_var is None:
running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var")

# eps_tensor for numerical stability
eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")

# adjusted_var = running_var + eps
adjusted_var = impl.elementwise.add(
ctx, target, source_ir, f"{name}_adjusted_var", running_var, eps_tensor
)
if weight is None:
weight = get_trt_tensor(ctx, 1.0, f"{name}_weight")
if bias is None:
bias = get_trt_tensor(ctx, 0.0, f"{name}_bias")
if running_mean is None:
running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
if running_var is None:
running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var")

# sqrt_adjusted_var = sqrt(adjusted_var)
sqrt_adjusted_var = impl.unary.sqrt(
ctx, target, source_ir, f"{name}_sqrt", adjusted_var
)
# eps_tensor for numerical stability
eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")

# scale = weight / sqrt_adjusted_var
scale = impl.elementwise.div(
ctx, target, source_ir, f"{name}_scale", weight, sqrt_adjusted_var
)
# adjusted_var = running_var + eps
adjusted_var = impl.elementwise.add(
ctx, target, source_ir, f"{name}_adjusted_var", running_var, eps_tensor
)

# scaled_running_mean = running_mean * scale
scaled_running_mean = impl.elementwise.mul(
ctx, target, source_ir, f"{name}_scaled_running_mean", running_mean, scale
)
# sqrt_adjusted_var = sqrt(adjusted_var)
sqrt_adjusted_var = impl.unary.sqrt(
ctx, target, source_ir, f"{name}_sqrt", adjusted_var
)

# bias_adjusted = bias - scaled_running_mean
bias_adjusted = impl.elementwise.sub(
ctx, target, source_ir, f"{name}_bias_adjusted", bias, scaled_running_mean
)
# scale = weight / sqrt_adjusted_var
scale = impl.elementwise.div(
ctx, target, source_ir, f"{name}_scale", weight, sqrt_adjusted_var
)

# Reshape scale and bias_adjusted to match input shape for broadcasting
expanded_shape = [1] * len(output_shape)
expanded_shape[1] = output_shape[1] # Set channel dimension
# scaled_running_mean = running_mean * scale
scaled_running_mean = impl.elementwise.mul(
ctx, target, source_ir, f"{name}_scaled_running_mean", running_mean, scale
)

scale_reshape = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_scale",
scale,
tuple(expanded_shape),
)
bias_adjusted_reshape = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_bias",
bias_adjusted,
tuple(expanded_shape),
)
# bias_adjusted = bias - scaled_running_mean
bias_adjusted = impl.elementwise.sub(
ctx, target, source_ir, f"{name}_bias_adjusted", bias, scaled_running_mean
)

# Apply the scale and bias to the input
scaled_input = impl.elementwise.mul(
ctx, target, source_ir, f"{name}_scaled_input", input, scale_reshape
)
output = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_output",
scaled_input,
bias_adjusted_reshape,
)
# Reshape scale and bias_adjusted to match input shape for broadcasting
expanded_shape = [1] * len(output_shape)
expanded_shape[1] = output_shape[1] # Set channel dimension

else:
# Handle the case when running_mean and running_var are not TRTTensor
if weight is None:
weight = 1.0
if bias is None:
bias = 0.0
if running_mean is None:
running_mean = 0.0
if running_var is None:
running_var = 1.0

scale = to_numpy(weight) / np.sqrt(to_numpy(running_var) + eps)
bias = to_numpy(bias) - to_numpy(running_mean) * scale
power = np.ones_like(scale)

# For BatchNorm1d, reshape 1d to 2d
if len(output_shape) < 4:
assert (
len(get_dynamic_dims(output_shape)) <= 1
), "BatchNorm1D with more than one dynamic dim is not currently supported."
new_shape = (
(output_shape[0], output_shape[1], 1, 1)
if len(output_shape) == 2
else (output_shape[0], output_shape[1], output_shape[2], 1)
)
input = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
)
scale_reshape = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_scale",
scale,
tuple(expanded_shape),
)
bias_adjusted_reshape = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_bias",
bias_adjusted,
tuple(expanded_shape),
)

layer = ctx.net.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power)
set_layer_name(layer, target, name, source_ir)
output = layer.get_output(0)
# Apply the scale and bias to the input
scaled_input = impl.elementwise.mul(
ctx, target, source_ir, f"{name}_scaled_input", input, scale_reshape
)
output = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_output",
scaled_input,
bias_adjusted_reshape,
)

# For BatchNorm1d, reshape output back to original shape if necessary
if len(output_shape) < 4:
Expand Down
49 changes: 49 additions & 0 deletions tests/py/dynamo/conversion/test_batch_norm_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,55 @@ def forward(self, x):
inputs,
)

def test_batchnorm_ITensor_weights_bias(self):
class BatchNorm(torch.nn.Module):
def forward(self, x, weight, bias):
return torch.ops.aten.batch_norm.default(
x,
weight,
bias,
torch.zeros((FEATURE_NUM,)),
torch.ones((FEATURE_NUM,)),
False,
0.1,
1e-05,
True,
)

inputs = [
torch.randn(1, 3, 224, 224),
torch.ones((FEATURE_NUM,)),
torch.zeros((FEATURE_NUM,)),
]
self.run_test(
BatchNorm(),
inputs,
)

def test_batchnorm_ITensor_weights(self):
class BatchNorm(torch.nn.Module):
def forward(self, x, weight):
return torch.ops.aten.batch_norm.default(
x,
weight,
None,
torch.zeros((FEATURE_NUM,)),
torch.ones((FEATURE_NUM,)),
False,
0.1,
1e-05,
True,
)

inputs = [
torch.randn(1, 3, 224, 224),
torch.ones((FEATURE_NUM,)),
]
self.run_test(
BatchNorm(),
inputs,
)

def test_batchnorm_static_bias_only(self):
class BatchNorm(torch.nn.Module):
def forward(self, x):
Expand Down

0 comments on commit 5d6cf81

Please sign in to comment.