Skip to content

Commit

Permalink
move output from inputs of forward and create it inside
Browse files Browse the repository at this point in the history
  • Loading branch information
awayzjj authored Mar 20, 2024
1 parent 4d7cc13 commit 4f522ac
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions tests/layer_tests/pytorch_tests/test_bucketize.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,10 @@

class TestBucketize(PytorchLayerTest):

def _prepare_input(self, input_shape, boundaries_range, input_dtype, boundaries_dtype, out_int32):
if out_int32:
output_dtype = "int32"
else:
output_dtype = "int64"
def _prepare_input(self, input_shape, boundaries_range, input_dtype, boundaries_dtype):
return (
np.random.randn(*input_shape).astype(input_dtype),
np.arange(*boundaries_range).astype(boundaries_dtype),
np.zeros(input_shape).astype(output_dtype))
np.arange(*boundaries_range).astype(boundaries_dtype))

def create_model(self, out_int32, right, is_out):
class aten_bucketize(torch.nn.Module):
Expand All @@ -29,8 +24,10 @@ def __init__(self, out_int32, right, is_out) -> None:
self.right = right
self.is_out = is_out

def forward(self, input, boundaries, output):
def forward(self, input, boundaries):
if self.is_out:
output_dtype = torch.int32 if self.out_int32 else torch.int64
output = torch.zeros_like(input, dtype=output_dtype)
torch.bucketize(input, boundaries, out_int32=self.out_int32, right=self.right, out=output)
return output
else:
Expand All @@ -53,5 +50,4 @@ def test_bucketize(self, input_shape, boundaries_range, input_dtype, boundaries_
self._test(*self.create_model(out_int32, right, is_out), ie_device, precision, ir_version, kwargs_to_prepare_input={
"input_shape": input_shape, "input_dtype": input_dtype,
"boundaries_range": boundaries_range, "boundaries_dtype": boundaries_dtype,
"out_int32": out_int32,
})
})

0 comments on commit 4f522ac

Please sign in to comment.