[PT FE]: add translation for aten::cumsum (#16092)
* [PT FE]: add translation for aten::cumsum

* handle out and prim::dtype
eaidova authored Mar 8, 2023
1 parent cd8999d commit 6514b76
Showing 3 changed files with 95 additions and 8 deletions.
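
For context (not part of the commit): the handled schema is aten::cumsum(Tensor self, int dim, *, ScalarType? dtype=None, Tensor out=None), so the new translator has to cover three call patterns. An illustrative PyTorch sketch of those patterns; tensor names are hypothetical:

import torch

x = torch.randn(1, 3, 4, 4)

# 1. Plain call: cumulative sum along a dimension, dtype follows the input.
a = torch.cumsum(x, dim=1)

# 2. Explicit dtype: the input is cast to the requested type before accumulation.
b = torch.cumsum(x, dim=1, dtype=torch.float64)

# 3. out= variant: the result is written into a preallocated tensor; taking the
#    target dtype from the out tensor itself is the pattern that appears as
#    prim::dtype in a scripted graph.
y = torch.zeros(1, 3, 4, 4, dtype=torch.float64)
torch.cumsum(x, dim=1, dtype=y.dtype, out=y)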
46 changes: 46 additions & 0 deletions src/frontends/pytorch/src/op/cumsum.cpp
@@ -0,0 +1,46 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/cum_sum.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_cumsum(NodeContext& context) {
// aten::cumsum(Tensor self, int dim, *, ScalarType? dtype=None, Tensor out=None)
num_inputs_check(context, 2, 4);
auto x = context.get_input(0);
auto dim = context.get_input(1);
if (!context.input_is_none(2)) {
if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(2).get_node_shared_ptr())) {
auto dtype = convert_dtype(context.const_input<int64_t>(2));
x = context.mark_node(std::make_shared<v0::Convert>(x, dtype));
} else if (const auto& fw_node = cast_fw_node(context.get_input(2).get_node_shared_ptr(), "prim::dtype")) {
auto out_tensor = fw_node->input_value(0);
x = context.mark_node(std::make_shared<v1::ConvertLike>(x, out_tensor));
} else {
FRONT_END_OP_CONVERSION_CHECK(false, "Couldn't get dtype input");
}
}
auto result = context.mark_node(std::make_shared<v0::CumSum>(x, dim));
if (!context.input_is_none(3)) {
context.mutate_input(3, result);
}
return {result};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
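
A note on the prim::dtype branch above (illustrative, not part of the commit): when a scripted model takes the target dtype from another tensor, TorchScript emits a prim::dtype node rather than a constant, so the translator converts x with ConvertLike against that node's input instead of reading a constant dtype. A minimal sketch of how that graph pattern arises, assuming torch is available; the module name is hypothetical:

import torch

class CumSumOut(torch.nn.Module):
    def forward(self, x, y):
        # y.dtype is not a compile-time constant here, so scripting lowers it
        # to a prim::dtype(%y) node feeding aten::cumsum.
        return torch.cumsum(x, 1, dtype=y.dtype, out=y)

scripted = torch.jit.script(CumSumOut())
print(scripted.graph)  # expect a prim::dtype node among the cumsum inputs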
3 changes: 2 additions & 1 deletion src/frontends/pytorch/src/op_table.cpp
@@ -32,6 +32,7 @@ OP_CONVERTER(translate_conv_transposend);
OP_CONVERTER(translate_convnd);
OP_CONVERTER(translate_convolution);
OP_CONVERTER(translate_convolution_mode);
OP_CONVERTER(translate_cumsum);
OP_CONVERTER(translate_dim);
OP_CONVERTER(translate_div);
OP_CONVERTER(translate_elu);
@@ -186,7 +187,7 @@ const std::map<std::string, PytorchCreatorFunction> get_supported_ops() {
{"aten::cos_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Cos>>},
{"aten::cosh", op::translate_1to1_match_1_inputs<opset10::Cosh>},
{"aten::cosh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Cosh>>},
{"aten::cumsum", op::translate_1to1_match_2_inputs<opset10::CumSum>},
{"aten::cumsum", op::translate_cumsum},
{"aten::dim", op::translate_dim},
{"aten::div", op::translate_div},
{"aten::div_", op::inplace_op<op::translate_div>},
54 changes: 47 additions & 7 deletions tests/layer_tests/pytorch_tests/test_cumsum.py
@@ -7,27 +7,67 @@


class TestCumSum(PytorchLayerTest):
def _prepare_input(self):
def _prepare_input(self, out=False, out_dtype=None):
import numpy as np
return (np.random.randn(1, 3, 224, 224).astype(np.float32),)
x = np.random.randn(1, 3, 224, 224).astype(np.float32)
if not out:
return (x, )
y = np.random.randn(1, 3, 224, 224).astype(np.float32)
if out_dtype is not None:
y = y.astype(out_dtype)
return (x, y)

def create_model(self, axis):

def create_model(self, axis, dtype_str, out, dtype_from_input):
import torch

dtypes = {
"float32": torch.float32,
"float64": torch.float64,
"int32": torch.int32,
"int64": torch.int64,
"int8": torch.int8,
"uint8": torch.uint8
}

dtype = dtypes.get(dtype_str)

class aten_cumsum(torch.nn.Module):
def __init__(self, axis):
def __init__(self, axis, dtype, out=False, dtype_from_input=False):
super(aten_cumsum, self).__init__()
self.axis = axis
self.dtype = dtype
if dtype_from_input:
self.forward_out = self.forward_out_prim_dtype
if out:
self.forward = self.forward_out
if self.dtype is not None:
if not dtype_from_input:
self.forward = self.forward_dtype if not out else self.forward_out_dtype

def forward(self, x):
return torch.cumsum(x, self.axis)

def forward_dtype(self, x):
return torch.cumsum(x, self.axis, dtype=self.dtype)

def forward_out(self, x, y):
return y, torch.cumsum(x, self.axis, out=y)

def forward_out_dtype(self, x, y):
return y, torch.cumsum(x, self.axis, dtype=self.dtype, out=y)

def forward_out_prim_dtype(self, x, y):
return y, torch.cumsum(x, self.axis, dtype=y.dtype, out=y)

ref_net = None

return aten_cumsum(axis), ref_net, "aten::cumsum"
return aten_cumsum(axis, dtype, out, dtype_from_input), ref_net, "aten::cumsum"

@pytest.mark.parametrize("axis", [0, 1, 2, 3, -1, -2, -3, -4])
@pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64", "int8"])
@pytest.mark.parametrize("out,dtype_from_input", [(False, False), (True, False), (True, True)])
@pytest.mark.nightly
@pytest.mark.precommit
def test_cumsum(self, axis, ie_device, precision, ir_version):
self._test(*self.create_model(axis), ie_device, precision, ir_version)
def test_cumsum(self, axis, dtype, out, dtype_from_input, ie_device, precision, ir_version):
self._test(*self.create_model(axis, dtype, out, dtype_from_input), ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out, "out_dtype": dtype})
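
As a quick out-of-harness sanity check of the semantics these tests cover (illustrative only, not part of the commit), taking the dtype from the out tensor should match passing the dtype explicitly:

import numpy as np
import torch

x = torch.from_numpy(np.random.randn(1, 3, 4, 4).astype(np.float32))
y = torch.zeros(1, 3, 4, 4, dtype=torch.float64)

explicit = torch.cumsum(x, 1, dtype=torch.float64)  # dtype passed directly
torch.cumsum(x, 1, dtype=y.dtype, out=y)             # dtype taken from out

assert torch.allclose(explicit, y)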
