Skip to content

Commit

Permalink
Merge pull request #71 from eaidova/ea/arange
Browse files Browse the repository at this point in the history
aten::arange
  • Loading branch information
slyalin authored Jan 3, 2023
2 parents 019bd37 + 01496ca commit 8969c80
Show file tree
Hide file tree
Showing 9 changed files with 638 additions and 43 deletions.
79 changes: 79 additions & 0 deletions src/frontends/pytorch/src/op/arange.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset8.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
// Convert aten::arange to an OpenVINO Range op.
// Handles the overloads that differ by input count:
//   2: (end, out)
//   4: (start, end, step, out)
//   5: (end, dtype, layout, device, pin_memory)
//   6: (start, end, dtype, layout, device, pin_memory)
//   7: (start, end, step, dtype, layout, device, pin_memory)
// When no explicit dtype is given, the result type follows the `out`
// tensor (or `end`) via ConvertLike.
OutputVector translate_arange(NodeContext& context) {
    auto zero = context.mark_node(opset8::Constant::create(element::i32, Shape{}, {0}));
    auto one = context.mark_node(opset8::Constant::create(element::i32, Shape{}, {1}));
    auto dtype = element::f32;
    bool dtype_applied = false;
    int num_inputs = context.get_input_size();
    ov::Output<Node> end;
    ov::Output<Node> out_tensor;
    ov::Output<Node> start = zero;
    ov::Output<Node> step = one;

    // aten::arange(Scalar end, Tensor out)
    if (num_inputs == 2) {
        end = context.get_input(0);
        out_tensor = context.input_is_none(1) ? end : context.get_input(1);
    }
    // aten::arange(Scalar start, Scalar end, Scalar step, Tensor out)
    if (num_inputs == 4) {
        start = context.get_input(0);
        end = context.get_input(1);
        step = context.get_input(2);
        out_tensor = context.input_is_none(3) ? end : context.get_input(3);
    }
    // aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
    if (num_inputs == 5) {
        end = context.get_input(0);
        out_tensor = end;
        if (!context.input_is_none(1)) {
            dtype = convert_dtype(context, 1);
            dtype_applied = true;
        }
    }
    // aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
    if (num_inputs == 6) {
        start = context.get_input(0);
        end = context.get_input(1);
        out_tensor = end;
        if (!context.input_is_none(2)) {
            dtype = convert_dtype(context, 2);
            dtype_applied = true;
        }
    }
    // aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
    if (num_inputs == 7) {
        start = context.get_input(0);
        end = context.get_input(1);
        step = context.get_input(2);
        out_tensor = end;
        if (!context.input_is_none(3)) {
            // Bug fix: dtype is input 3 in this overload, not 2 (input 2 is step).
            dtype = convert_dtype(context, 3);
            dtype_applied = true;
        }
    }
    // Range requires start/end/step in the destination element type.
    auto r_end = context.mark_node(std::make_shared<opset8::Convert>(end, dtype));
    auto r_start = context.mark_node(std::make_shared<opset8::Convert>(start, dtype));
    auto r_step = context.mark_node(std::make_shared<opset8::Convert>(step, dtype));
    auto range = context.mark_node(std::make_shared<opset8::Range>(r_start, r_end, r_step, dtype));
    if (!dtype_applied) {
        // No explicit dtype was supplied: inherit the element type of the out tensor.
        range = context.mark_node(std::make_shared<opset8::ConvertLike>(range, out_tensor));
    }
    return {range};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
27 changes: 17 additions & 10 deletions src/frontends/pytorch/src/op/as_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,25 @@ namespace pytorch {
namespace op {

// Convert aten::as_tensor / aten::tensor.
// Reconstructed post-change version: the scraped diff interleaved the
// removed pre-change lines with the new ones, leaving duplicate
// declarations and dead branches; those stale lines are dropped here.
OutputVector translate_as_tensor(NodeContext& context) {
    auto dtype = element::f32;
    Output<Node> cast;
    if (!context.input_is_none(1)) {
        auto dtype_ext_node = context.get_input_from_visible_context(1).get_node_shared_ptr();
        auto dtype_fw_node = std::dynamic_pointer_cast<PtFrameworkNode>(dtype_ext_node);
        if (dtype_fw_node && dtype_fw_node->get_op_type() == "prim::dtype") {
            // dtype comes from prim::dtype(t): cast the data to t's element type.
            auto type_input = dtype_fw_node->input_value(0);
            return {context.mark_node(std::make_shared<opset8::ConvertLike>(context.get_input(0), type_input))};
        }
        if (auto dtype_const = std::dynamic_pointer_cast<opset8::Constant>(dtype_ext_node)) {
            // dtype is a constant PyTorch scalar-type code; map it to an OV type.
            auto pt_type = dtype_const->cast_vector<int64_t>()[0];
            FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type),
                                          "Unknown type in aten::as_tensor: ",
                                          pt_type);
            dtype = TORCH_TO_OV_TYPE.at(pt_type);
        }
    }
    cast = context.mark_node(std::make_shared<opset8::Convert>(context.get_input(0), dtype));

    // Input with index 2 is device, we skip this input
    return {cast};
};
Expand Down
105 changes: 87 additions & 18 deletions src/frontends/pytorch/src/op/full.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,70 +11,139 @@ namespace frontend {
namespace pytorch {
namespace op {

// Produce a tensor of shape `sizes` filled with `value` (Broadcast).
ov::Output<Node> base_translate_full(NodeContext& context, ov::Output<Node> sizes, ov::Output<Node> value) {
    auto filled = std::make_shared<opset8::Broadcast>(value, sizes);
    return context.mark_node(filled);
}

// Fill a tensor of shape `sizes` with `value`, then, if the dtype input at
// `dtype_id` is present, convert the result to that element type.
ov::Output<Node> base_translate_full_with_convert(NodeContext& context,
                                                  ov::Output<Node> sizes,
                                                  ov::Output<Node> value,
                                                  size_t dtype_id) {
    auto result = base_translate_full(context, sizes, value);
    if (context.input_is_none(dtype_id)) {
        // No dtype requested: return the filled tensor as-is.
        return result;
    }
    auto dtype = convert_dtype(context, dtype_id);
    return context.mark_node(std::make_shared<opset8::Convert>(result, dtype));
}

// Fill a tensor of shape `sizes` with `value` and cast it to the element
// type of `out` via ConvertLike.
ov::Output<Node> base_translate_full_with_convertlike(NodeContext& context,
                                                      ov::Output<Node> sizes,
                                                      ov::Output<Node> value,
                                                      ov::Output<Node> out) {
    auto filled = base_translate_full(context, sizes, value);
    auto like = std::make_shared<opset8::ConvertLike>(filled, out);
    return context.mark_node(like);
}

// Convert aten::full. Overloads with fewer than 6 inputs carry an optional
// `out` tensor; larger overloads carry an optional dtype.
// Defect fixed: the scraped diff retained the stale pre-change
// `return {Broadcast(...)}` line, which made the new logic unreachable.
OutputVector translate_full(NodeContext& context) {
    auto sizes = context.get_input(0);
    auto value = context.get_input(1);
    int num_inputs = context.get_input_size();
    if (num_inputs < 6) {
        // (sizes, value, out) or (sizes, value, names, out)
        size_t out_id = num_inputs == 3 ? 2 : 3;
        if (!context.input_is_none(out_id)) {
            auto out = context.get_input(out_id);
            return {base_translate_full_with_convertlike(context, sizes, value, out)};
        }
        return {base_translate_full(context, sizes, value)};
    }
    // (sizes, value, dtype, layout, device, pin_memory[, ...])
    size_t dtype_id = num_inputs == 6 ? 2 : 3;
    return {base_translate_full_with_convert(context, sizes, value, dtype_id)};
};

// Convert aten::full_like: fill with `value` in the shape of `input`.
// Defect fixed: stale pre-change ShapeOf/Broadcast/return lines from the
// scraped diff (dead duplicates) are removed.
OutputVector translate_full_like(NodeContext& context) {
    auto input = context.get_input(0);
    auto value = context.get_input(1);
    auto sizes = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
    if (context.get_input_size() == 7) {
        // (input, value, dtype, layout, device, pin_memory, memory_format)
        return {base_translate_full_with_convert(context, sizes, value, 2)};
    }
    // Otherwise inherit the element type from `out` (input 3) or from `input`.
    auto out = context.input_is_none(3) ? input : context.get_input(3);
    return {base_translate_full_with_convertlike(context, sizes, value, out)};
};

// Convert aten::new_full: fill shape `sizes` with `value`, typed like `input`
// unless an explicit dtype (input 3 in the 7-input overload) is given.
// Defect fixed: stale pre-change Broadcast/ConvertLike lines from the
// scraped diff are removed.
OutputVector translate_new_full(NodeContext& context) {
    auto input = context.get_input(0);
    auto sizes = context.get_input(1);
    auto value = context.get_input(2);
    if (context.get_input_size() == 7) {
        return {base_translate_full_with_convert(context, sizes, value, 3)};
    }
    return {base_translate_full_with_convertlike(context, sizes, value, input)};
};

// Convert aten::zeros. Mirrors translate_full with a constant 0 fill value.
// Defect fixed: the scraped diff retained the stale pre-change
// `return {Broadcast(...)}` line, which made the new logic unreachable.
OutputVector translate_zeros(NodeContext& context) {
    auto sizes = context.get_input(0);
    auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {0}));
    int num_inputs = context.get_input_size();
    if (num_inputs < 5) {
        // (sizes, out) or (sizes, names, out)
        size_t out_id = num_inputs == 2 ? 1 : 2;
        if (!context.input_is_none(out_id)) {
            auto out = context.get_input(out_id);
            return {base_translate_full_with_convertlike(context, sizes, value, out)};
        }
        return {base_translate_full(context, sizes, value)};
    }
    // (sizes, dtype, layout, device, pin_memory[, ...])
    size_t dtype_id = num_inputs == 5 ? 1 : 2;
    return {base_translate_full_with_convert(context, sizes, value, dtype_id)};
};

// Convert aten::zeros_like: zero-filled tensor in the shape of `input`.
// Defect fixed: stale pre-change ShapeOf/Broadcast/return lines from the
// scraped diff (dead duplicates) are removed.
OutputVector translate_zeros_like(NodeContext& context) {
    auto input = context.get_input(0);
    auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {0}));
    auto sizes = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
    if (context.get_input_size() == 6) {
        // (input, dtype, layout, device, pin_memory, memory_format)
        return {base_translate_full_with_convert(context, sizes, value, 1)};
    }
    // Otherwise inherit the element type from `out` (input 2) or from `input`.
    auto out = context.input_is_none(2) ? input : context.get_input(2);
    return {base_translate_full_with_convertlike(context, sizes, value, out)};
};

// Convert aten::new_zeros: zero-filled tensor of shape `sizes`, typed like
// `input` unless an explicit dtype (input 2 in the 6-input overload) is given.
// Defect fixed: stale pre-change Broadcast/ConvertLike lines from the
// scraped diff are removed.
OutputVector translate_new_zeros(NodeContext& context) {
    auto input = context.get_input(0);
    auto sizes = context.get_input(1);
    auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {0}));
    if (context.get_input_size() == 6) {
        return {base_translate_full_with_convert(context, sizes, value, 2)};
    }
    return {base_translate_full_with_convertlike(context, sizes, value, input)};
};

// Convert aten::ones. Mirrors translate_zeros with a constant 1 fill value.
// Defect fixed: the scraped diff retained the stale pre-change
// `return {Broadcast(...)}` line, which made the new logic unreachable.
OutputVector translate_ones(NodeContext& context) {
    auto sizes = context.get_input(0);
    auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {1}));
    int num_inputs = context.get_input_size();
    if (num_inputs < 5) {
        // (sizes, out) or (sizes, names, out)
        size_t out_id = num_inputs == 2 ? 1 : 2;
        if (!context.input_is_none(out_id)) {
            auto out = context.get_input(out_id);
            return {base_translate_full_with_convertlike(context, sizes, value, out)};
        }
        return {base_translate_full(context, sizes, value)};
    }
    // (sizes, dtype, layout, device, pin_memory[, ...])
    size_t dtype_id = num_inputs == 5 ? 1 : 2;
    return {base_translate_full_with_convert(context, sizes, value, dtype_id)};
};

// Convert aten::ones_like: one-filled tensor in the shape of `input`.
// Defect fixed: stale pre-change ShapeOf/Broadcast/return lines from the
// scraped diff (dead duplicates) are removed.
OutputVector translate_ones_like(NodeContext& context) {
    auto input = context.get_input(0);
    auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {1}));
    auto sizes = context.mark_node(std::make_shared<opset8::ShapeOf>(input));
    if (context.get_input_size() == 6) {
        // (input, dtype, layout, device, pin_memory, memory_format)
        return {base_translate_full_with_convert(context, sizes, value, 1)};
    }
    // Otherwise inherit the element type from `out` (input 2) or from `input`.
    auto out = context.input_is_none(2) ? input : context.get_input(2);
    return {base_translate_full_with_convertlike(context, sizes, value, out)};
};

// Convert aten::new_ones: one-filled tensor of shape `sizes`, typed like
// `input` unless an explicit dtype (input 2 in the 6-input overload) is given.
// Defect fixed: stale pre-change Broadcast/ConvertLike lines from the
// scraped diff are removed.
OutputVector translate_new_ones(NodeContext& context) {
    auto input = context.get_input(0);
    auto sizes = context.get_input(1);
    auto value = context.mark_node(opset8::Constant::create(element::f32, Shape{}, {1}));
    if (context.get_input_size() == 6) {
        return {base_translate_full_with_convert(context, sizes, value, 2)};
    }
    return {base_translate_full_with_convertlike(context, sizes, value, input)};
};

} // namespace op
Expand Down
3 changes: 3 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ OP_CONVERTER(translate_adaptive_max_pool2d);
OP_CONVERTER(translate_add);
OP_CONVERTER(translate_addcmul);
OP_CONVERTER(translate_addmm);
OP_CONVERTER(translate_arange);
OP_CONVERTER(translate_as_tensor);
OP_CONVERTER(translate_avg_pool2d);
OP_CONVERTER(translate_batch_norm);
Expand Down Expand Up @@ -106,6 +107,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::add_", op::inplace_op<op::translate_add>},
{"aten::addcmul", op::translate_addcmul},
{"aten::addmm", op::translate_addmm},
{"aten::arange", op::translate_arange},
{"aten::asin", op::translate_1to1_match_1_inputs<opset8::Asin>},
{"aten::asin_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Asin>>},
{"aten::asinh", op::translate_1to1_match_1_inputs<opset8::Asinh>},
Expand Down Expand Up @@ -220,6 +222,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::tan_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Tan>>},
{"aten::tanh", op::translate_1to1_match_1_inputs<opset8::Tanh>},
{"aten::tanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset8::Tanh>>},
{"aten::tensor", op::translate_as_tensor},
{"aten::type_as",
op::translate_1to1_match_2_inputs<opset8::ConvertLike>}, // TODO: overflow semantics is different
{"aten::to", op::translate_to},
Expand Down
6 changes: 6 additions & 0 deletions src/frontends/pytorch/src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,12 @@ std::shared_ptr<Node> numel(NodeContext& context, size_t input_id) {
return context.mark_node(std::make_shared<opset8::ReduceProd>(input_shape, axes, false));
};

// Read the constant PyTorch scalar-type code at `input_id` and map it to the
// corresponding OpenVINO element type; fails conversion for unknown codes.
ov::element::Type convert_dtype(NodeContext& context, size_t input_id) {
    const auto pt_type = context.const_input<int64_t>(input_id);
    const auto it = TORCH_TO_OV_TYPE.find(pt_type);
    FRONT_END_OP_CONVERSION_CHECK(it != TORCH_TO_OV_TYPE.end(), "Unknown type: ", pt_type);
    return it->second;
};

std::shared_ptr<Node> concat_list_construct(std::shared_ptr<Node> input) {
if (auto list_construct = cast_fw_node(input, "prim::ListConstruct")) {
auto list_inputs = list_construct->input_values();
Expand Down
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ std::shared_ptr<Node> get_axes_range(NodeContext& context, size_t input_id);

std::shared_ptr<Node> numel(NodeContext& context, size_t input_id);

ov::element::Type convert_dtype(NodeContext& context, size_t input_id);

std::shared_ptr<Node> concat_list_construct(std::shared_ptr<Node> input);

std::shared_ptr<ov::Model> convert_pytorch_model(std::shared_ptr<Decoder> pytorch_model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _test(self, model, ref_net, kind, ie_device, precision, ir_version, infer_ti
assert ov_tensor == fw_tensor
assert type(fw_tensor) == type(ov_tensor)
continue
assert torch.tensor(np.array(ov_tensor)).dtype == fw_tensor.dtype
assert torch.tensor(np.array(ov_tensor)).dtype == fw_tensor.dtype, f"dtype validation failed: {torch.tensor(np.array(ov_tensor)).dtype} != {fw_tensor.dtype}"

if 'custom_eps' in kwargs and kwargs['custom_eps'] is not None:
custom_eps = kwargs['custom_eps']
Expand Down
Loading

0 comments on commit 8969c80

Please sign in to comment.