diff --git a/tests/layer_tests/pytorch_tests/test_bucketize.py b/tests/layer_tests/pytorch_tests/test_bucketize.py index 359d8fff86e9fd..29fb550708e464 100644 --- a/tests/layer_tests/pytorch_tests/test_bucketize.py +++ b/tests/layer_tests/pytorch_tests/test_bucketize.py @@ -10,15 +10,10 @@ class TestBucketize(PytorchLayerTest): - def _prepare_input(self, input_shape, boundaries_range, input_dtype, boundaries_dtype, out_int32): - if out_int32: - output_dtype = "int32" - else: - output_dtype = "int64" + def _prepare_input(self, input_shape, boundaries_range, input_dtype, boundaries_dtype): return ( np.random.randn(*input_shape).astype(input_dtype), - np.arange(*boundaries_range).astype(boundaries_dtype), - np.zeros(input_shape).astype(output_dtype)) + np.arange(*boundaries_range).astype(boundaries_dtype)) def create_model(self, out_int32, right, is_out): class aten_bucketize(torch.nn.Module): @@ -29,8 +24,10 @@ def __init__(self, out_int32, right, is_out) -> None: self.right = right self.is_out = is_out - def forward(self, input, boundaries, output): + def forward(self, input, boundaries): if self.is_out: + output_dtype = torch.int32 if self.out_int32 else torch.int64 + output = torch.zeros_like(input, dtype=output_dtype) torch.bucketize(input, boundaries, out_int32=self.out_int32, right=self.right, out=output) return output else: @@ -53,5 +50,4 @@ def test_bucketize(self, input_shape, boundaries_range, input_dtype, boundaries_ self._test(*self.create_model(out_int32, right, is_out), ie_device, precision, ir_version, kwargs_to_prepare_input={ "input_shape": input_shape, "input_dtype": input_dtype, "boundaries_range": boundaries_range, "boundaries_dtype": boundaries_dtype, - "out_int32": out_int32, - }) \ No newline at end of file + })