Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

aten::arange #71

Merged
merged 7 commits into from
Jan 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions src/frontends/pytorch/src/op/arange.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset8.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
OutputVector translate_arange(NodeContext& context) {
auto zero = context.mark_node(opset8::Constant::create(element::i32, Shape{}, {0}));
auto one = context.mark_node(opset8::Constant::create(element::i32, Shape{}, {1}));
auto dtype = element::f32;
bool dtype_applied = false;
int num_inputs = context.get_input_size();
ov::Output<Node> end;
ov::Output<Node> out_tensor;
ov::Output<Node> start = zero;
ov::Output<Node> step = one;

// aten::arange(Scalar end, tensor out)
if (num_inputs == 2) {
end = context.get_input(0);
out_tensor = context.input_is_none(1) ? end : context.get_input(1);
}
// # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out)
if (num_inputs == 4) {
start = context.get_input(0);
end = context.get_input(1);
step = context.get_input(2);
out_tensor = context.input_is_none(3) ? end : context.get_input(3);
}
// aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
if (num_inputs == 5) {
end = context.get_input(0);
out_tensor = end;
if (!context.input_is_none(1)) {
dtype = convert_dtype(context, 1);
dtype_applied = true;
}
}
// aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
if (num_inputs == 6) {
start = context.get_input(0);
end = context.get_input(1);
out_tensor = end;
if (!context.input_is_none(2)) {
dtype = convert_dtype(context, 2);
dtype_applied = true;
}
}
// aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
if (num_inputs == 7) {
start = context.get_input(0);
end = context.get_input(1);
step = context.get_input(2);
out_tensor = end;
if (!context.input_is_none(3)) {
dtype = convert_dtype(context, 2);
dtype_applied = true;
}
}
auto r_end = context.mark_node(std::make_shared<opset8::Convert>(end, dtype));
auto r_start = context.mark_node(std::make_shared<opset8::Convert>(start, dtype));
auto r_step = context.mark_node(std::make_shared<opset8::Convert>(step, dtype));
auto range = context.mark_node(std::make_shared<opset8::Range>(r_start, r_end, r_step, dtype));
if (!dtype_applied) {
range = context.mark_node(std::make_shared<opset8::ConvertLike>(range, out_tensor));
}
return {range};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
27 changes: 17 additions & 10 deletions src/frontends/pytorch/src/op/as_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,25 @@ namespace pytorch {
namespace op {

OutputVector translate_as_tensor(NodeContext& context) {
auto dtype_ext_node = context.get_input_from_visible_context(1).get_node_shared_ptr();
auto dtype_fw_node = std::dynamic_pointer_cast<PtFrameworkNode>(dtype_ext_node);
auto dtype = element::f32;
Output<Node> cast;
if (dtype_fw_node && dtype_fw_node->get_op_type() == "prim::dtype") {
auto type_input = dtype_fw_node->input_value(0);
cast = context.mark_node(std::make_shared<opset8::ConvertLike>(context.get_input(0), type_input));
} else if (const auto dtype_const = std::dynamic_pointer_cast<opset8::Constant>(dtype_ext_node)) {
auto pt_type = dtype_const->cast_vector<int64_t>()[0];
FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type in aten::as_tensor: ", pt_type);
auto dtype = TORCH_TO_OV_TYPE.at(pt_type);
cast = context.mark_node(std::make_shared<opset8::Convert>(context.get_input(0), dtype));
if (!context.input_is_none(1)) {
auto dtype_ext_node = context.get_input_from_visible_context(1).get_node_shared_ptr();
auto dtype_fw_node = std::dynamic_pointer_cast<PtFrameworkNode>(dtype_ext_node);
if (dtype_fw_node && dtype_fw_node->get_op_type() == "prim::dtype") {
auto type_input = dtype_fw_node->input_value(0);
return {context.mark_node(std::make_shared<opset8::ConvertLike>(context.get_input(0), type_input))};
}
if (auto dtype_const = std::dynamic_pointer_cast<opset8::Constant>(dtype_ext_node)) {
auto pt_type = dtype_const->cast_vector<int64_t>()[0];
FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type),
"Unknown type in aten::as_tensor: ",
pt_type);
dtype = TORCH_TO_OV_TYPE.at(pt_type);
}
}
cast = context.mark_node(std::make_shared<opset8::Convert>(context.get_input(0), dtype));

// Input with index 2 is device, we skip this input
return {cast};
};
Expand Down
105 changes: 87 additions & 18 deletions src/frontends/pytorch/src/op/full.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,70 +11,139 @@ namespace frontend {
namespace pytorch {
namespace op {

ov::Output<Node> base_translate_full(NodeContext& context, ov::Output<Node> sizes, ov::Output<Node> value) {
return context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
}

ov::Output<Node> base_translate_full_with_convert(NodeContext& context,
ov::Output<Node> sizes,
ov::Output<Node> value,
size_t dtype_id) {
auto filled_tensor = base_translate_full(context, sizes, value);
if (!context.input_is_none(dtype_id)) {
auto dtype = convert_dtype(context, dtype_id);
filled_tensor = context.mark_node(std::make_shared<opset8::Convert>(filled_tensor, dtype));
}
return filled_tensor;
}

ov::Output<Node> base_translate_full_with_convertlike(NodeContext& context,
ov::Output<Node> sizes,
ov::Output<Node> value,
ov::Output<Node> out) {
auto filled_tensor = base_translate_full(context, sizes, value);
return context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, out));
}

OutputVector translate_full(NodeContext& context) {
auto sizes = context.get_input(0);
auto value = context.get_input(1);
return {context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes))};
int num_inputs = context.get_input_size();
if (num_inputs < 6) {
size_t out_id = num_inputs == 3 ? 2 : 3;
if (!context.input_is_none(out_id)) {
auto out = context.get_input(out_id);
return {base_translate_full_with_convertlike(context, sizes, value, out)};
}
return {base_translate_full(context, sizes, value)};
}
size_t dtype_id = num_inputs == 6 ? 2 : 3;
return {base_translate_full_with_convert(context, sizes, value, dtype_id)};
};

OutputVector translate_full_like(NodeContext& context) {
auto input = context.get_input(0);
auto value = context.get_input(1);
auto input_shape = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, input_shape));
return {filled_tensor};
auto sizes = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
if (context.get_input_size() == 7) {
return {base_translate_full_with_convert(context, sizes, value, 2)};
}
auto out = context.input_is_none(3) ? input : context.get_input(3);
return {base_translate_full_with_convertlike(context, sizes, value, out)};
};

OutputVector translate_new_full(NodeContext& context) {
auto input = context.get_input(0);
auto sizes = context.get_input(1);
auto value = context.get_input(2);
auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, input))};
if (context.get_input_size() == 7) {
return {base_translate_full_with_convert(context, sizes, value, 3)};
}
return {base_translate_full_with_convertlike(context, sizes, value, input)};
};

OutputVector translate_zeros(NodeContext& context) {
auto sizes = context.get_input(0);
auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {0}));
return {context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes))};
int num_inputs = context.get_input_size();
if (num_inputs < 5) {
size_t out_id = num_inputs == 2 ? 1 : 2;
if (!context.input_is_none(out_id)) {
auto out = context.get_input(out_id);
return {base_translate_full_with_convertlike(context, sizes, value, out)};
}
return {base_translate_full(context, sizes, value)};
}
size_t dtype_id = num_inputs == 5 ? 1 : 2;
return {base_translate_full_with_convert(context, sizes, value, dtype_id)};
};

OutputVector translate_zeros_like(NodeContext& context) {
auto input = context.get_input(0);
auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {0}));
auto input_shape = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, input_shape));
return {filled_tensor};
auto sizes = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
if (context.get_input_size() == 6) {
return {base_translate_full_with_convert(context, sizes, value, 1)};
}
auto out = context.input_is_none(2) ? input : context.get_input(2);
return {base_translate_full_with_convertlike(context, sizes, value, out)};
};

OutputVector translate_new_zeros(NodeContext& context) {
auto input = context.get_input(0);
auto sizes = context.get_input(1);
auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {0}));
auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, input))};
if (context.get_input_size() == 6) {
return {base_translate_full_with_convert(context, sizes, value, 2)};
}
return {base_translate_full_with_convertlike(context, sizes, value, input)};
};

OutputVector translate_ones(NodeContext& context) {
auto sizes = context.get_input(0);
auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {1}));
return {context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes))};
int num_inputs = context.get_input_size();
if (num_inputs < 5) {
size_t out_id = num_inputs == 2 ? 1 : 2;
if (!context.input_is_none(out_id)) {
auto out = context.get_input(out_id);
return {base_translate_full_with_convertlike(context, sizes, value, out)};
}
return {base_translate_full(context, sizes, value)};
}
size_t dtype_id = num_inputs == 5 ? 1 : 2;
return {base_translate_full_with_convert(context, sizes, value, dtype_id)};
};

OutputVector translate_ones_like(NodeContext& context) {
auto input = context.get_input(0);
auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {1}));
auto input_shape = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, input_shape));
return {filled_tensor};
auto sizes = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
if (context.get_input_size() == 6) {
return {base_translate_full_with_convert(context, sizes, value, 1)};
}
auto out = context.input_is_none(2) ? input : context.get_input(2);
return {base_translate_full_with_convertlike(context, sizes, value, out)};
};

OutputVector translate_new_ones(NodeContext& context) {
auto input = context.get_input(0);
auto sizes = context.get_input(1);
auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {1}));
auto filled_tensor = context.mark_node(std::make_shared<opset8::Broadcast>(value, sizes));
return {context.mark_node(std::make_shared<opset8::ConvertLike>(filled_tensor, input))};
if (context.get_input_size() == 6) {
return {base_translate_full_with_convert(context, sizes, value, 2)};
}
return {base_translate_full_with_convertlike(context, sizes, value, input)};
};

} // namespace op
Expand Down
3 changes: 3 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ OP_CONVERTER(translate_adaptive_max_pool2d);
OP_CONVERTER(translate_add);
OP_CONVERTER(translate_addcmul);
OP_CONVERTER(translate_addmm);
OP_CONVERTER(translate_arange);
OP_CONVERTER(translate_as_tensor);
OP_CONVERTER(translate_avg_pool2d);
OP_CONVERTER(translate_batch_norm);
Expand Down Expand Up @@ -104,6 +105,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::add_", op::inplace_op<op::translate_add>},
{"aten::addcmul", op::translate_addcmul},
{"aten::addmm", op::translate_addmm},
{"aten::arange", op::translate_arange},
{"aten::asin", op::translate_1to1_match_1_inputs<opset8::Asin>},
{"aten::asin_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Asin>>},
{"aten::asinh", op::translate_1to1_match_1_inputs<opset8::Asinh>},
Expand Down Expand Up @@ -217,6 +219,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::tan_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Tan>>},
{"aten::tanh", op::translate_1to1_match_1_inputs<opset8::Tanh>},
{"aten::tanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Tanh>>},
{"aten::tensor", op::translate_as_tensor},
{"aten::type_as",
op::translate_1to1_match_2_inputs<opset8::ConvertLike>}, // TODO: overflow semantics is different
{"aten::to", op::translate_to},
Expand Down
6 changes: 6 additions & 0 deletions src/frontends/pytorch/src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,12 @@ std::shared_ptr<Node> numel(NodeContext& context, size_t input_id) {
return context.mark_node(std::make_shared<opset8::ReduceProd>(input_shape, axes, false));
};

ov::element::Type convert_dtype(NodeContext& context, size_t input_id) {
auto pt_type = context.const_input<int64_t>(input_id);
FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type: ", pt_type);
return TORCH_TO_OV_TYPE.at(pt_type);
};

std::shared_ptr<Node> concat_list_construct(std::shared_ptr<Node> input) {
if (auto list_construct = cast_fw_node(input, "prim::ListConstruct")) {
auto list_inputs = list_construct->input_values();
Expand Down
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ std::shared_ptr<Node> get_axes_range(NodeContext& context, size_t input_id);

std::shared_ptr<Node> numel(NodeContext& context, size_t input_id);

ov::element::Type convert_dtype(NodeContext& context, size_t input_id);

std::shared_ptr<Node> concat_list_construct(std::shared_ptr<Node> input);

std::shared_ptr<ov::Model> convert_pytorch_model(std::shared_ptr<Decoder> pytorch_model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _test(self, model, ref_net, kind, ie_device, precision, ir_version, infer_ti
assert ov_tensor == fw_tensor
assert type(fw_tensor) == type(ov_tensor)
continue
assert torch.tensor(np.array(ov_tensor)).dtype == fw_tensor.dtype
assert torch.tensor(np.array(ov_tensor)).dtype == fw_tensor.dtype, f"dtype validation failed: {torch.tensor(np.array(ov_tensor)).dtype} != {fw_tensor.dtype}"

if 'custom_eps' in kwargs and kwargs['custom_eps'] is not None:
custom_eps = kwargs['custom_eps']
Expand Down
Loading