diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 8846497348..c988eaccee 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -563,7 +563,7 @@ def aten_ops_rsqrt( ) -@dynamo_tensorrt_converter(torch.ops.aten.neg.default) +@dynamo_tensorrt_converter(torch.ops.aten.neg.default, supports_dynamic_shapes=True) def aten_ops_neg( ctx: ConversionContext, target: Target, diff --git a/tests/py/dynamo/conversion/test_neg_aten.py b/tests/py/dynamo/conversion/test_neg_aten.py index 795a78354f..a0439c02bc 100644 --- a/tests/py/dynamo/conversion/test_neg_aten.py +++ b/tests/py/dynamo/conversion/test_neg_aten.py @@ -42,6 +42,46 @@ def forward(self, input): check_dtype=False, ) + @parameterized.expand( + [ + ( + "2d_dim_dtype_half", + (1, 1), + (2, 2), + (4, 4), + torch.half, + torch.half, + ), + ( + "3d_dim_dtype_float", + (1, 1, 1), + (1, 2, 3), + (3, 3, 3), + torch.float, + torch.float, + ), + ] + ) + def test_dynamic_shape_neg( + self, _, min_shape, opt_shape, max_shape, type, output_type + ): + class neg(nn.Module): + def forward(self, input): + return torch.ops.aten.neg.default(input) + + input_specs = [ + Input( + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + dtype=type, + ), + ] + + self.run_test_with_dynamic_shape( + neg(), input_specs, output_dtypes=[output_type] + ) + if __name__ == "__main__": run_tests()