[PT FE]: support mixed precision in aten::min/max (#19936)
* [PT FE]: support mixed precision in aten::min/max

* fix eltwise dtype alignment for float16
eaidova authored Sep 20, 2023
1 parent c67c066 commit 2c88fbf
Showing 7 changed files with 57 additions and 18 deletions.
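
For context, "mixed precision" here means the two operands of torch.min/torch.max may carry different dtypes (for example float16 vs. int32); the PyTorch frontend now aligns the operand types itself before building Maximum/Minimum. A minimal sketch of the newly supported case, illustrative only and not part of this commit, assuming a recent openvino package whose ov.convert_model accepts a torch.nn.Module with example inputs:

import openvino as ov
import torch

class MixedMax(torch.nn.Module):
    def forward(self, x, y):
        # float16 tensor vs. int32 tensor; PyTorch promotes the result to float16
        return torch.max(x, y)

example = (torch.randn(1, 3, 10, 10, dtype=torch.float16),
           torch.randint(0, 10, (1, 3, 10, 10), dtype=torch.int32))
ov_model = ov.convert_model(MixedMax(), example_input=example)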
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op/min_max.cpp
@@ -33,6 +33,7 @@ OutputVector translate_max(const NodeContext& context) {
    // torch.max(input, other)
    if (context.input_is_none(2)) {
        auto y = context.get_input(1);
+        align_eltwise_input_types(context, x, y, true);
        return {context.mark_node(std::make_shared<v1::Maximum>(x, y))};
    }
    // torch.max(input, dim, keepdim), returns values and indicies
@@ -62,6 +63,7 @@ OutputVector translate_min(const NodeContext& context) {
    // torch.min(input, other)
    if (context.input_is_none(2)) {
        auto y = context.get_input(1);
+        align_eltwise_input_types(context, x, y, true);
        return {context.mark_node(std::make_shared<v1::Minimum>(x, y))};
    }
    // torch.min(input, dim, keepdim), returns values and indicies
8 changes: 4 additions & 4 deletions src/frontends/pytorch/src/utils.cpp
@@ -422,9 +422,9 @@ void align_eltwise_input_types(const NodeContext& context, Output<Node>& lhs, Ou
            // if div we need to also align float types to highest bitness regardless of scalar
            if (!align_scalars)
                lhs_dst_type = element::f32;
-            rhs_dst_type = element::f32;
+            rhs_dst_type = lhs_type;
        } else if (is_rhs_scalar && !lhs_type.is_real() && rhs_type.is_real()) {
-            lhs_dst_type = element::f32;
+            lhs_dst_type = rhs_type;
            // if div we need to also align float types to highest bitness regardless of scalar
            if (!align_scalars)
                rhs_dst_type = element::f32;
@@ -437,9 +437,9 @@ void align_eltwise_input_types(const NodeContext& context, Output<Node>& lhs, Ou
    }

    if (!lhs_dst_type.is_real() && rhs_dst_type.is_real()) {
-        lhs_dst_type = element::f32;
+        lhs_dst_type = rhs_dst_type;
    } else if (lhs_dst_type.is_real() && !rhs_dst_type.is_real()) {
-        rhs_dst_type = element::f32;
+        rhs_dst_type = lhs_dst_type;
    }
    // Align bool to other type
    if (lhs_dst_type == element::boolean) {
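
The net effect of the utils.cpp change: when only one side of an elementwise op resolves to a floating-point type, the other side is now cast to that same floating type (so float16 stays float16) instead of being promoted to f32 unconditionally. A rough Python sketch of the rule, illustrative only and not the frontend's actual code:

FLOAT_TYPES = {"f16", "f32", "f64"}

def align_dtypes(lhs, rhs):
    # The non-float operand adopts the float operand's element type.
    if lhs not in FLOAT_TYPES and rhs in FLOAT_TYPES:
        lhs = rhs    # e.g. i32 with f16 -> both f16
    elif lhs in FLOAT_TYPES and rhs not in FLOAT_TYPES:
        rhs = lhs    # e.g. f16 with i64 -> both f16
    return lhs, rhs

print(align_dtypes("i32", "f16"))  # ('f16', 'f16'); the removed code forced the integer side to f32 here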
7 changes: 4 additions & 3 deletions tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
@@ -25,7 +25,8 @@ class PytorchLayerTest:
        "int64": Type.i64,
        "int16": Type.i16,
        "int8": Type.i8,
-        "uint8": Type.u8
+        "uint8": Type.u8,
+        "float16": Type.f16
    }

    @staticmethod
@@ -120,8 +121,8 @@ def use_torch_compile_backend():
                    continue
                assert ov_type == fw_type, f"dtype validation failed: {ov_type} != {fw_type}"
                continue
-            assert torch.tensor(np.array(
-                ov_tensor)).dtype == fw_tensor.dtype, f"dtype validation failed: {torch.tensor(np.array(ov_tensor)).dtype} != {fw_tensor.dtype}"
+            ov_tensor_fw_format = torch.tensor(np.array(ov_tensor))
+            assert ov_tensor_fw_format.dtype == fw_tensor.dtype, f"dtype validation failed: {ov_tensor_fw_format.dtype} != {fw_tensor.dtype}"

        # Compare Ie results with Framework results
        fw_eps = custom_eps if precision == 'FP32' else 5e-2
6 changes: 6 additions & 0 deletions tests/layer_tests/pytorch_tests/test_add.py
@@ -96,6 +96,12 @@ def forward3(self, lhs, rhs):
        [torch.float32, torch.int32],
        [torch.float32, torch.int64],
        [torch.float32, torch.float64],
+        [torch.float16, torch.uint8],
+        [torch.uint8, torch.float16],
+        [torch.float16, torch.int32],
+        [torch.int32, torch.float16],
+        [torch.float16, torch.int64],
+        [torch.int64, torch.float16]
    ])
    @pytest.mark.parametrize(("lhs_shape", "rhs_shape"), [([2, 3], [2, 3]),
                                                          ([2, 3], []),
6 changes: 6 additions & 0 deletions tests/layer_tests/pytorch_tests/test_div.py
@@ -98,6 +98,12 @@ def forward3(self, lhs, rhs):
        [torch.float32, torch.int32],
        [torch.float32, torch.int64],
        [torch.float32, torch.float64],
+        [torch.float16, torch.uint8],
+        [torch.uint8, torch.float16],
+        [torch.float16, torch.int32],
+        [torch.int32, torch.float16],
+        [torch.float16, torch.int64],
+        [torch.int64, torch.float16]
    ])
    @pytest.mark.parametrize(("lhs_shape", "rhs_shape"), [([2, 3], [2, 3]),
                                                          ([2, 3], []),
40 changes: 29 additions & 11 deletions tests/layer_tests/pytorch_tests/test_min_max.py
@@ -7,19 +7,27 @@


class TestMinMax(PytorchLayerTest):
-    def _prepare_input(self, second_input=False):
+    def _prepare_input(self, input_dtype="float32", second_input=False, second_input_dtype="float32"):
        import numpy as np
        if not second_input:
-            return (np.random.randn(1, 3, 10, 10).astype(np.float32),)
-        return (np.random.randn(1, 3, 10, 10).astype(np.float32), np.random.randn(1, 3, 10, 10).astype(np.float32))
+            return (np.random.randn(1, 3, 10, 10).astype(input_dtype),)
+        return (np.random.randn(1, 3, 10, 10).astype(input_dtype), np.random.randn(1, 3, 10, 10).astype(second_input_dtype))

-    def create_model(self, op_type, axes, keep_dims, single_input=True):
+    def create_model(self, op_type, axes, keep_dims, single_input=True, dtypes=("float32", "float32")):
        import torch
        op_types = {
            'max': torch.max,
            'min': torch.min
        }

+        dtypes_map = {
+            "float32": torch.float32,
+            "float64": torch.float64,
+            "int32": torch.int32,
+            "int64": torch.int64,
+            "uint8": torch.uint8
+        }
+
        op = op_types[op_type]

        class aten_min_max(torch.nn.Module):
@@ -41,17 +49,23 @@ def forward(self, x):
                return self.op(x, self.axes, self.keep_dims)

        class aten_min_max_2args(torch.nn.Module):
-            def __init__(self, op):
+            def __init__(self, op, l_dtype, r_dtype):
                super(aten_min_max_2args, self).__init__()
                self.op = op
+                self.l_dtype = l_dtype
+                self.r_dtype = r_dtype

            def forward(self, x, y):
-                return self.op(x, y)
+                return self.op(x.to(self.l_dtype), y.to(self.r_dtype))

        ref_net = None
        if axes is None and keep_dims is None:
-            model_cls = aten_min_max(
-                op) if single_input else aten_min_max_2args(op)
+            if single_input:
+                model_cls = aten_min_max(op)
+            else:
+                l_dtype = dtypes_map[dtypes[0]]
+                r_dtype = dtypes_map[dtypes[1]]
+                model_cls = aten_min_max_2args(op, l_dtype, r_dtype)
        else:
            model_cls = aten_min_max_3args(op, axes, keep_dims)
@@ -66,11 +80,15 @@ def test_reduce_min_max(self, axes, keep_dims, op_type, ie_device, precision, ir
                                      single_input=True), ie_device, precision, ir_version)

    @pytest.mark.parametrize("op_type", ['min', 'max'])
+    @pytest.mark.parametrize("second_input_dtype", ["float32", "int32", "float64", "int64", "uint8"])
+    @pytest.mark.parametrize("first_input_dtype", ["float32", "int32", "float64", "int64", "uint8"])
    @pytest.mark.nightly
    @pytest.mark.precommit
-    def test_min_max(self, op_type, ie_device, precision, ir_version):
-        self._test(*self.create_model(op_type, None, None, single_input=False),
-                   ie_device, precision, ir_version, kwargs_to_prepare_input={"second_input": True})
+    def test_min_max(self, op_type, first_input_dtype, second_input_dtype, ie_device, precision, ir_version):
+        self._test(*self.create_model(op_type, None, None, single_input=False, dtypes=(first_input_dtype, second_input_dtype)),
+                   ie_device, precision, ir_version, kwargs_to_prepare_input=
+                   {"second_input": True, "input_dtype": first_input_dtype, "second_input_dtype": second_input_dtype}
+                   )


class TestPrimMax(PytorchLayerTest):
6 changes: 6 additions & 0 deletions tests/layer_tests/pytorch_tests/test_mul.py
@@ -87,6 +87,12 @@ def forward3(self, lhs, rhs):
        [torch.float32, torch.int32],
        [torch.float32, torch.int64],
        [torch.float32, torch.float64],
+        [torch.float16, torch.uint8],
+        [torch.uint8, torch.float16],
+        [torch.float16, torch.int32],
+        [torch.int32, torch.float16],
+        [torch.float16, torch.int64],
+        [torch.int64, torch.float16]
    ])
    @pytest.mark.parametrize(("lhs_shape", "rhs_shape"), [([2, 3], [2, 3]),
                                                          ([2, 3], []),
