Skip to content

Commit

Permalink
Merge branch 'pytorch_frontend' into ea/arange
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Dec 29, 2022
2 parents 2a5e9c7 + 42adb07 commit 23088a9
Show file tree
Hide file tree
Showing 10 changed files with 308 additions and 39 deletions.
12 changes: 7 additions & 5 deletions src/frontends/pytorch/src/op/as_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@ namespace op {
OutputVector translate_as_tensor(NodeContext& context) {
auto dtype = element::f32;
Output<Node> cast;
if (!context.input_is_none(1)){
if (!context.input_is_none(1)) {
auto dtype_ext_node = context.get_input_from_visible_context(1).get_node_shared_ptr();
auto dtype_fw_node = std::dynamic_pointer_cast<PtFrameworkNode>(dtype_ext_node);
if (dtype_fw_node && dtype_fw_node->get_op_type() == "prim::dtype") {
auto type_input = dtype_fw_node->input_value(0);
return {context.mark_node(std::make_shared<opset8::ConvertLike>(context.get_input(0), type_input))};
}
if (auto dtype_const = std::dynamic_pointer_cast<opset8::Constant>(dtype_ext_node)){
auto pt_type = dtype_const->cast_vector<int64_t>()[0];
FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::as_tensor: ", pt_type);
dtype = TORCH_TO_OV_TYPE.at(pt_type);
if (auto dtype_const = std::dynamic_pointer_cast<opset8::Constant>(dtype_ext_node)) {
auto pt_type = dtype_const->cast_vector<int64_t>()[0];
FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type),
"Unknown type in aten::as_tensor: ",
pt_type);
dtype = TORCH_TO_OV_TYPE.at(pt_type);
}
}
cast = context.mark_node(std::make_shared<opset8::Convert>(context.get_input(0), dtype));
Expand Down
58 changes: 28 additions & 30 deletions src/frontends/pytorch/src/op/full.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ OutputVector translate_full(NodeContext& context) {

auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
if (num_inputs < 6) {
size_t out_id = num_inputs == 3 ? 2: 3;
if (!context.input_is_none(out_id)){
auto out = context.get_input(out_id);
return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out))};
size_t out_id = num_inputs == 3 ? 2 : 3;
if (!context.input_is_none(out_id)) {
auto out = context.get_input(out_id);
return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out))};
}
}
size_t dtype_id = num_inputs == 6 ? 2: 3;
if (!context.input_is_none(dtype_id)){
size_t dtype_id = num_inputs == 6 ? 2 : 3;
if (!context.input_is_none(dtype_id)) {
auto pt_type = context.const_input<int64_t>(dtype_id);
FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::full: ", pt_type);
auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
Expand All @@ -39,13 +39,13 @@ OutputVector translate_full_like(NodeContext& context) {
auto value = context.get_input(1);
auto input_shape = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, input_shape));
if (context.get_input_size() == 7 && !context.input_is_none(2)){
if (context.get_input_size() == 7 && !context.input_is_none(2)) {
auto pt_type = context.const_input<int64_t>(2);
FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::full_like: ", pt_type);
auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
} else {
auto out_dtype = context.input_is_none(3)? input : context.get_input(3);
auto out_dtype = context.input_is_none(3) ? input : context.get_input(3);
filled_tensor = context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out_dtype));
}
return {filled_tensor};
Expand All @@ -71,15 +71,15 @@ OutputVector translate_zeros(NodeContext& context) {
auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
int num_inputs = context.get_input_size();
if (num_inputs < 5) {
size_t out_id = num_inputs == 2 ? 1: 2;
if (!context.input_is_none(out_id)){
auto out = context.get_input(out_id);
return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out))};
size_t out_id = num_inputs == 2 ? 1 : 2;
if (!context.input_is_none(out_id)) {
auto out = context.get_input(out_id);
return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out))};
}
return {filled_tensor};
}
size_t dtype_id = num_inputs == 5 ? 1: 2;
if (!context.input_is_none(dtype_id)){
size_t dtype_id = num_inputs == 5 ? 1 : 2;
if (!context.input_is_none(dtype_id)) {
std::cout << dtype_id << std::endl;
auto pt_type = context.const_input<int64_t>(dtype_id);
std::cout << pt_type << std::endl;
Expand All @@ -95,14 +95,13 @@ OutputVector translate_zeros_like(NodeContext& context) {
auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {0}));
auto input_shape = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, input_shape));
if (context.get_input_size() == 6 && !context.input_is_none(1)){
if (context.get_input_size() == 6 && !context.input_is_none(1)) {
auto pt_type = context.const_input<int64_t>(1);
FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::zeros_like: ", pt_type);
auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
}
else {
auto out_dtype = context.input_is_none(2)? input : context.get_input(2);
} else {
auto out_dtype = context.input_is_none(2) ? input : context.get_input(2);
filled_tensor = context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out_dtype));
}
return {filled_tensor};
Expand All @@ -113,7 +112,7 @@ OutputVector translate_new_zeros(NodeContext& context) {
auto sizes = context.get_input(1);
auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {0}));
auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
if (context.get_input_size() == 6 && !context.input_is_none(2)){
if (context.get_input_size() == 6 && !context.input_is_none(2)) {
auto pt_type = context.const_input<int64_t>(2);
FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::new_zeros: ", pt_type);
auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
Expand All @@ -128,14 +127,14 @@ OutputVector translate_ones(NodeContext& context) {
auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
int num_inputs = context.get_input_size();
if (num_inputs < 5) {
size_t out_id = num_inputs == 2 ? 1: 2;
if (!context.input_is_none(out_id)){
auto out = context.get_input(out_id);
return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out))};
size_t out_id = num_inputs == 2 ? 1 : 2;
if (!context.input_is_none(out_id)) {
auto out = context.get_input(out_id);
return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out))};
}
}
size_t dtype_id = num_inputs == 5 ? 1: 2;
if (!context.input_is_none(dtype_id)){
size_t dtype_id = num_inputs == 5 ? 1 : 2;
if (!context.input_is_none(dtype_id)) {
auto pt_type = context.const_input<int64_t>(dtype_id);
FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::ones: ", pt_type);
auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
Expand All @@ -149,14 +148,13 @@ OutputVector translate_ones_like(NodeContext& context) {
auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {1}));
auto input_shape = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, input_shape));
if (context.get_input_size() == 6 && !context.input_is_none(1)){
if (context.get_input_size() == 6 && !context.input_is_none(1)) {
auto pt_type = context.const_input<int64_t>(1);
FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::ones_like: ", pt_type);
auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
}
else {
auto out_dtype = context.input_is_none(2)? input : context.get_input(2);
} else {
auto out_dtype = context.input_is_none(2) ? input : context.get_input(2);
filled_tensor = context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out_dtype));
}
return {filled_tensor};
Expand All @@ -167,7 +165,7 @@ OutputVector translate_new_ones(NodeContext& context) {
auto sizes = context.get_input(1);
auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {1}));
auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
if (context.get_input_size() == 6 && !context.input_is_none(2)){
if (context.get_input_size() == 6 && !context.input_is_none(2)) {
auto pt_type = context.const_input<int64_t>(2);
FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::new_zeros: ", pt_type);
auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
Expand Down
8 changes: 6 additions & 2 deletions src/frontends/pytorch/src/op/upsample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ OutputVector translate_upsample2d(NodeContext& context, opset8::Interpolate::Int
auto size_mode = opset8::Interpolate::ShapeCalcMode::SIZES;
bool align_corners = false;
int scale_id = 2;
if (interpolate_mode == opset8::Interpolate::InterpolateMode::LINEAR_ONNX) {
if (interpolate_mode != opset8::Interpolate::InterpolateMode::NEAREST) {
scale_id = 3;
if (!context.input_is_none(2)) {
align_corners = context.const_input<bool>(2);
Expand All @@ -38,7 +38,7 @@ OutputVector translate_upsample2d(NodeContext& context, opset8::Interpolate::Int
auto attrs = opset8::Interpolate::InterpolateAttrs(interpolate_mode, size_mode, pad, pad);
attrs.coordinate_transformation_mode = opset8::Interpolate::CoordinateTransformMode::ASYMMETRIC;
attrs.nearest_mode = opset8::Interpolate::NearestMode::FLOOR;
if (attrs.mode == opset8::Interpolate::InterpolateMode::LINEAR_ONNX) {
if (attrs.mode != opset8::Interpolate::InterpolateMode::NEAREST) {
if (align_corners) {
attrs.coordinate_transformation_mode = opset8::Interpolate::CoordinateTransformMode::ALIGN_CORNERS;
}
Expand All @@ -54,6 +54,10 @@ OutputVector translate_upsample_nearest2d(NodeContext& context) {
return translate_upsample2d(context, opset8::Interpolate::InterpolateMode::NEAREST);
};

OutputVector translate_upsample_bicubic2d(NodeContext& context) {
return translate_upsample2d(context, opset8::Interpolate::InterpolateMode::CUBIC);
};

} // namespace op
} // namespace pytorch
} // namespace frontend
Expand Down
32 changes: 32 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ OP_CONVERTER(translate_sum);
OP_CONVERTER(translate_to);
OP_CONVERTER(translate_transpose);
OP_CONVERTER(translate_tuple_construct);
OP_CONVERTER(translate_upsample_bicubic2d);
OP_CONVERTER(translate_upsample_bilinear2d);
OP_CONVERTER(translate_upsample_nearest2d);
OP_CONVERTER(translate_var);
Expand All @@ -92,6 +93,10 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::_convolution", op::translate_convolution},
{"aten::_convolution_mode", op::translate_convolution_mode},
{"aten::abs", op::translate_1to1_match_1_inputs<opset8::Abs>},
{"aten::acos", op::translate_1to1_match_1_inputs<opset8::Acos>},
{"aten::acos_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Acos>>},
{"aten::acosh", op::translate_1to1_match_1_inputs<opset8::Acosh>},
{"aten::acosh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Acosh>>},
{"aten::adaptive_avg_pool2d", op::translate_1to1_match_2_inputs<opset8::AdaptiveAvgPool>},
{"aten::adaptive_avg_pool3d", op::translate_adaptive_avg_pool3d},
{"aten::adaptive_max_pool2d", op::translate_adaptive_max_pool2d},
Expand All @@ -100,18 +105,32 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::addcmul", op::translate_addcmul},
{"aten::addmm", op::translate_addmm},
{"aten::arange", op::translate_arange},
{"aten::asin", op::translate_1to1_match_1_inputs<opset8::Asin>},
{"aten::asin_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Asin>>},
{"aten::asinh", op::translate_1to1_match_1_inputs<opset8::Asinh>},
{"aten::asinh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Asinh>>},
{"aten::as_tensor", op::translate_as_tensor},
{"aten::atan", op::translate_1to1_match_1_inputs<opset8::Atan>},
{"aten::atan_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Atan>>},
{"aten::atanh", op::translate_1to1_match_1_inputs<opset8::Atanh>},
{"aten::atanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Atanh>>},
{"aten::avg_pool2d", op::translate_avg_pool2d},
{"aten::batch_norm", op::translate_batch_norm},
// {"aten::cat", done as transformation},
{"aten::clamp", op::translate_clamp},
{"aten::clamp_min", op::translate_1to1_match_2_inputs<opset8::Maximum>},
{"aten::clamp_max", op::translate_1to1_match_2_inputs<opset8::Minimum>},
{"aten::ceil", op::translate_1to1_match_1_inputs<opset8::Ceiling>},
{"aten::ceil_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Ceiling>>},
{"aten::clone", op::skip_node}, // ignore clone operators that are inserted by PyTorch autograd
{"aten::contiguous", op::skip_node}, // In openvino how tensors are stored in memory is internal plugin detail,
// we assume all tensors are contiguous
{"aten::conv2d", op::translate_conv2d},
{"aten::convolution", op::translate_convolution},
{"aten::cos", op::translate_1to1_match_1_inputs<opset8::Cos>},
{"aten::cos_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Cos>>},
{"aten::cosh", op::translate_1to1_match_1_inputs<opset8::Cosh>},
{"aten::cosh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Cosh>>},
{"aten::cumsum", op::translate_1to1_match_2_inputs<opset8::CumSum>},
{"aten::dim", op::translate_dim},
{"aten::div", op::translate_div},
Expand All @@ -123,11 +142,14 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::expand", op::translate_expand},
{"aten::expand_as", op::translate_expand_as},
{"aten::flatten", op::translate_flatten},
{"aten::floor", op::translate_1to1_match_1_inputs<opset8::Floor>},
{"aten::floor_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Floor>>},
{"aten::floordiv", op::translate_floordiv},
{"aten::full", op::translate_full},
{"aten::full_like", op::translate_full_like},
{"aten::gelu", op::translate_gelu},
{"aten::group_norm", op::translate_group_norm},
{"aten::ge", op::translate_1to1_match_2_inputs<opset8::GreaterEqual>},
{"aten::gt", op::translate_1to1_match_2_inputs<opset8::Greater>},
{"aten::hardsigmoid", op::translate_1to1_match_1_inputs<opset8::HSigmoid>},
{"aten::hardswish", op::translate_1to1_match_1_inputs<opset8::HSwish>},
Expand All @@ -140,6 +162,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::leaky_relu", op::translate_1to1_match_2_inputs<opset8::PRelu>},
{"aten::leaky_relu_", op::inplace_op<op::translate_1to1_match_2_inputs<opset8::PRelu>>},
{"aten::linear", op::translate_linear},
{"aten::le", op::translate_1to1_match_2_inputs<opset8::LessEqual>},
{"aten::lt", op::translate_1to1_match_2_inputs<opset8::Less>},
{"aten::matmul", op::translate_1to1_match_2_inputs<opset8::MatMul>},
{"aten::masked_fill", op::translate_masked_fill},
Expand Down Expand Up @@ -178,6 +201,10 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::sigmoid", op::translate_1to1_match_1_inputs<opset8::Sigmoid>},
{"aten::silu", op::translate_1to1_match_1_inputs<opset8::Swish>},
{"aten::silu_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Swish>>},
{"aten::sin", op::translate_1to1_match_1_inputs<opset8::Sin>},
{"aten::sin_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Sin>>},
{"aten::sinh", op::translate_1to1_match_1_inputs<opset8::Sinh>},
{"aten::sinh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Sinh>>},
{"aten::size", op::translate_size},
{"aten::slice", op::translate_slice},
{"aten::softmax", op::translate_softmax},
Expand All @@ -186,13 +213,18 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::squeeze", op::translate_squeeze},
{"aten::sub", op::translate_sub},
{"aten::sum", op::translate_sum},
{"aten::tan", op::translate_1to1_match_1_inputs<opset8::Tan>},
{"aten::tan_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Tan>>},
{"aten::tanh", op::translate_1to1_match_1_inputs<opset8::Tanh>},
{"aten::tanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Tanh>>},
{"aten::tensor", op::translate_as_tensor},
{"aten::type_as",
op::translate_1to1_match_2_inputs<opset8::ConvertLike>}, // TODO: overflow semantics is different
{"aten::to", op::translate_to},
{"aten::transpose", op::translate_transpose},
{"aten::unsqueeze", op::translate_1to1_match_2_inputs<opset8::Unsqueeze>},
{"aten::unsqueeze_", op::inplace_op<op::translate_1to1_match_2_inputs<opset8::Unsqueeze>>},
{"aten::upsample_bicubic2d", op::translate_upsample_bicubic2d},
{"aten::upsample_bilinear2d", op::translate_upsample_bilinear2d},
{"aten::upsample_nearest2d", op::translate_upsample_nearest2d},
{"aten::var", op::translate_var},
Expand Down
31 changes: 31 additions & 0 deletions tests/layer_tests/pytorch_tests/test_ceil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
from pytorch_layer_test_class import PytorchLayerTest


class TestCeil(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
return (np.random.randn(1, 3, 224, 224).astype(np.float32),)

def create_model(self, inplace):
import torch

class aten_ceil(torch.nn.Module):
def __init__(self, inplace):
super(aten_ceil, self).__init__()
self.op = torch.ceil_ if inplace else torch.ceil

def forward(self, x):
return x, self.op(x)

ref_net = None

return aten_ceil(inplace), ref_net, "aten::ceil" if not inplace else "aten::ceil_"

@pytest.mark.parametrize("inplace", [False, True])
@pytest.mark.nightly
def test_ceil(self, inplace, ie_device, precision, ir_version):
self._test(*self.create_model(inplace), ie_device, precision, ir_version)
Loading

0 comments on commit 23088a9

Please sign in to comment.