Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PT FE] Support aten::aminmax for pytorch models #23879

Merged
merged 8 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/frontends/pytorch/src/op/min_max.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,26 @@ OutputVector translate_amax(const NodeContext& context) {
return {res};
}

OutputVector translate_aminmax(const NodeContext& context) {
num_inputs_check(context, 1, 4); // Expect between 1 and 4 inputs
// (input tensor, dim = none, keepdim = false, out = none)

auto input = context.get_input(0);

// check if dim is provided, if not, get the range of axes to compute min and max
auto dim = !context.input_is_none(1) ? context.get_input(1) : get_axes_range(context, 0);

// check if keepdim is provided, if not, set it to false like PyTorch
bool keep_dims = !context.input_is_none(2) ? context.const_input<bool>(2) : false;

auto amin = context.mark_node(std::make_shared<v1::ReduceMin>(input, dim, keep_dims));
auto amax = context.mark_node(std::make_shared<v1::ReduceMax>(input, dim, keep_dims));

PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(3), "out argument is not supported for aten::aminmax");

return {amin, amax};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
Expand Down
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ OP_CONVERTER(translate_adaptive_max_pool2d);
OP_CONVERTER(translate_adaptive_max_pool1d);
OP_CONVERTER(translate_add);
OP_CONVERTER(translate_add_);
OP_CONVERTER(translate_aminmax);
OP_CONVERTER(translate_mul);
OP_CONVERTER(translate_mul_);
OP_CONVERTER(translate_addcmul);
Expand Down Expand Up @@ -352,6 +353,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::all", op::translate_all},
{"aten::amax", op::translate_amax},
{"aten::amin", op::translate_amin},
{"aten::aminmax", op::translate_aminmax},
// aten::append - Supported in limited set of patterns
{"aten::arange", op::translate_arange},
{"aten::argmax", op::translate_argmax},
Expand Down
60 changes: 60 additions & 0 deletions tests/layer_tests/pytorch_tests/test_aminmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest

class TestAminMax(PytorchLayerTest):
def _prepare_input(self, inputs, dtype=None):
import numpy as np
return [np.array(inputs).astype(dtype)]

def create_model(self, dtype=None, dim=None, keepdim=False):
dtype_map = {
"float32": torch.float32,
"float64": torch.float64,
"int32": torch.int32,
"int64": torch.int64,
}

dtype = dtype_map.get(dtype)

class aten_aminmax(torch.nn.Module):
def __init__(self, dtype, dim, keepdim):
super().__init__()
self.dtype = dtype
self.dim = dim
self.keepdim = keepdim

def forward(self, x):
return torch.aminmax(x.to(self.dtype), dim=self.dim, keepdim=self.keepdim, out=None)

model_class = aten_aminmax(dtype, dim, keepdim)

ref_net = None

return model_class, ref_net, "aten::aminmax"

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64"])
@pytest.mark.parametrize("inputs", [[0, 1, 2, 3, 4, -1],
[-2, -1, 0, 1, 2, 3],
[1, 2, 3, 4, 5, 6]])
@pytest.mark.parametrize("dim,keepdim", [(None, False), # Test with default arguments
(0, False), # Test with dim provided and keepdim=False
(0, True), # Test with dim provided and keepdim=True
(None, True)]) # Test with keepdim=True and dim not provided
def test_aminmax(self, dtype, inputs, ie_device,
precision, ir_version, dim, keepdim):
self._test(
*self.create_model(dtype=dtype, dim=dim, keepdim=keepdim),
ie_device,
precision,
ir_version,
trace_model=True,
freeze_model=False,
kwargs_to_prepare_input={"inputs": inputs, "dtype": dtype}
)
Loading