[PT FE] Support aten::aminmax for pytorch models (openvinotoolkit#23879)
### Details:
 - Implemented the `aten::aminmax` operation (a sketch of the op's PyTorch semantics follows this list)
 - Implemented a layer test for the `aminmax` op
 - Registered the converter in `op_table.cpp`
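For context, a minimal sketch of what `aten::aminmax` computes on the PyTorch side (the tensor and values here are illustrative, not taken from the PR): it returns a `(min, max)` pair, optionally reduced along a given `dim`, and is equivalent to separate `amin`/`amax` reductions, which is exactly the decomposition the converter uses.

```python
import torch

x = torch.tensor([[1.0, -2.0], [3.0, 0.5]])

# Full reduction: scalar min and max over all elements
mn, mx = torch.aminmax(x)

# Reduction along a dimension, keeping it in the output shape
mn_d, mx_d = torch.aminmax(x, dim=0, keepdim=True)

# aminmax is equivalent to separate amin/amax reductions
assert torch.equal(mn_d, x.amin(dim=0, keepdim=True))
assert torch.equal(mx_d, x.amax(dim=0, keepdim=True))
```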

### Tickets:
 - openvinotoolkit#23327

---------

Co-authored-by: Maxim Vafin <[email protected]>
LucaTamSapienza and mvafin authored Apr 15, 2024
1 parent 19d208e commit e3fbb57
Showing 3 changed files with 82 additions and 0 deletions.
20 changes: 20 additions & 0 deletions src/frontends/pytorch/src/op/min_max.cpp
@@ -207,6 +207,26 @@ OutputVector translate_amax(const NodeContext& context) {
    return {res};
}

OutputVector translate_aminmax(const NodeContext& context) {
    num_inputs_check(context, 1, 4);  // expect 1 to 4 inputs:
                                      // (input tensor, dim = None, keepdim = false, out = None)

    auto input = context.get_input(0);

    // if dim is not provided, reduce over the full range of axes
    auto dim = !context.input_is_none(1) ? context.get_input(1) : get_axes_range(context, 0);

    // if keepdim is not provided, default to false, matching PyTorch
    bool keep_dims = !context.input_is_none(2) ? context.const_input<bool>(2) : false;

    auto amin = context.mark_node(std::make_shared<v1::ReduceMin>(input, dim, keep_dims));
    auto amax = context.mark_node(std::make_shared<v1::ReduceMax>(input, dim, keep_dims));

    PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(3), "out argument is not supported for aten::aminmax");

    return {amin, amax};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
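For an end-to-end check of the new conversion path, something along these lines should work (a sketch, assuming a recent `openvino` Python package where `ov.convert_model` accepts a `torch.nn.Module`; the model and names below are illustrative):

```python
import torch
import openvino as ov

class AminMax(torch.nn.Module):
    def forward(self, x):
        # Traced as aten::aminmax, which the new converter lowers
        # to ReduceMin + ReduceMax
        return torch.aminmax(x, dim=1, keepdim=True)

example = torch.randn(2, 3)
ov_model = ov.convert_model(AminMax(), example_input=example)
compiled = ov.compile_model(ov_model, "CPU")

results = compiled(example)  # two outputs: the min and max tensors
mn, mx = results[0], results[1]
```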
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
@@ -24,6 +24,7 @@ OP_CONVERTER(translate_adaptive_max_pool2d);
OP_CONVERTER(translate_adaptive_max_pool1d);
OP_CONVERTER(translate_add);
OP_CONVERTER(translate_add_);
OP_CONVERTER(translate_aminmax);
OP_CONVERTER(translate_mul);
OP_CONVERTER(translate_mul_);
OP_CONVERTER(translate_addcmul);
@@ -352,6 +353,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::all", op::translate_all},
{"aten::amax", op::translate_amax},
{"aten::amin", op::translate_amin},
{"aten::aminmax", op::translate_aminmax},
// aten::append - Supported in limited set of patterns
{"aten::arange", op::translate_arange},
{"aten::argmax", op::translate_argmax},
60 changes: 60 additions & 0 deletions tests/layer_tests/pytorch_tests/test_aminmax.py
@@ -0,0 +1,60 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class TestAminMax(PytorchLayerTest):
    def _prepare_input(self, inputs, dtype=None):
        import numpy as np
        return [np.array(inputs).astype(dtype)]

    def create_model(self, dtype=None, dim=None, keepdim=False):
        dtype_map = {
            "float32": torch.float32,
            "float64": torch.float64,
            "int32": torch.int32,
            "int64": torch.int64,
        }

        dtype = dtype_map.get(dtype)

        class aten_aminmax(torch.nn.Module):
            def __init__(self, dtype, dim, keepdim):
                super().__init__()
                self.dtype = dtype
                self.dim = dim
                self.keepdim = keepdim

            def forward(self, x):
                return torch.aminmax(x.to(self.dtype), dim=self.dim,
                                     keepdim=self.keepdim, out=None)

        model_class = aten_aminmax(dtype, dim, keepdim)

        ref_net = None

        return model_class, ref_net, "aten::aminmax"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64"])
    @pytest.mark.parametrize("inputs", [[0, 1, 2, 3, 4, -1],
                                        [-2, -1, 0, 1, 2, 3],
                                        [1, 2, 3, 4, 5, 6]])
    @pytest.mark.parametrize("dim,keepdim", [(None, False),  # default arguments
                                             (0, False),     # dim provided, keepdim=False
                                             (0, True),      # dim provided, keepdim=True
                                             (None, True)])  # keepdim=True, dim not provided
    def test_aminmax(self, dtype, inputs, ie_device,
                     precision, ir_version, dim, keepdim):
        self._test(
            *self.create_model(dtype=dtype, dim=dim, keepdim=keepdim),
            ie_device,
            precision,
            ir_version,
            trace_model=True,
            freeze_model=False,
            kwargs_to_prepare_input={"inputs": inputs, "dtype": dtype}
        )
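Assuming the standard layer-test environment (`pytorch_layer_test_class` on the path, an OpenVINO build installed), the new test could presumably be run in isolation with something like:

```python
import pytest

# Hypothetical direct invocation; path taken from the diff above
pytest.main(["tests/layer_tests/pytorch_tests/test_aminmax.py", "-m", "precommit"])
```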
