
Added more ops needed for torchbench models
ynimmaga committed Mar 29, 2024
1 parent 8913c1f commit 2e948f2
Showing 7 changed files with 52 additions and 4 deletions.
@@ -20,7 +20,7 @@
 from openvino.runtime import Core, Type, PartialShape
 from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
 from openvino.frontend.pytorch.torchdynamo import decompositions
-from openvino.frontend.pytorch.torchdynamo.decompositions import get_aot_decomposition_list
+from openvino.frontend.pytorch.torchdynamo.decompositions import get_aot_decomposition_list, get_inf_decomposition_list
 from openvino.frontend.pytorch.torchdynamo.partition import Partitioner
 from openvino.frontend.pytorch.torchdynamo.execute import execute, execute_cached
 from openvino.frontend.pytorch.torchdynamo.compile import cached_model_name, openvino_compile_cached_model
@@ -146,7 +146,7 @@ def _call(*args):
 example_inputs.reverse()

 from torch._subclasses.fake_tensor import FakeTensorMode
-decompositions = _get_decompositions(options)
+decompositions = _get_decompositions(options) + get_inf_decomposition_list()
 if (_get_aot_autograd(options)):
     decompositions = decompositions + get_aot_decomposition_list()
 with FakeTensorMode(allow_non_fake_inputs=True):
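For context, the list returned by get_inf_decomposition_list() (defined in the next file) is appended to the op list that this backend hands to the tracer. A minimal sketch, assuming a recent PyTorch where torch._decomp is available, of how such an op list is typically materialized into a decomposition table (illustration only, not part of this commit):

    # Sketch: turn a list of aten overloads into a decomposition table.
    # get_inf_decomposition_list() is assumed to return
    # [torch.ops.aten.nll_loss_forward.default].
    import torch
    from torch._decomp import get_decompositions

    decomposition_ops = [torch.ops.aten.nll_loss_forward.default]
    decomposition_table = get_decompositions(decomposition_ops)
    # Expected to hold on recent PyTorch, where nll_loss_forward has a
    # registered reference decomposition:
    assert torch.ops.aten.nll_loss_forward.default in decomposition_table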
@@ -111,3 +111,6 @@ def get_aot_decomposition_list():
 torch.ops.aten.native_layer_norm.default,
 torch.ops.aten.native_layer_norm_backward.default,
 torch.ops.aten.slice_backward.default])
+
+def get_inf_decomposition_list():
+    return ([torch.ops.aten.nll_loss_forward.default])
@@ -174,6 +174,7 @@ def __init__(self, options):
 "torch.ops.aten.mm.default": None,
 "torch.ops.aten.mul.Scalar": None,
 "torch.ops.aten.mul.Tensor": None,
+"torch.ops.aten.mul_.Tensor": None,
 "torch.ops.aten.native_batch_norm.default": None,
 "torch.ops.aten.native_dropout.default": None,
 "torch.ops.aten.native_group_norm.default": None,
@@ -191,6 +192,7 @@
 "torch.ops.aten.pow.Tensor_Scalar": None,
 "torch.ops.aten.pow.Tensor_Tensor": None,
 "torch.ops.aten.rand.default": None,
+"torch.ops.aten.reflection_pad2d.default": None,
 "torch.ops.aten.reciprocal.default": None,
 "torch.ops.aten.relu.default": None,
 "torch.ops.aten.relu_.default": None,
@@ -205,6 +207,7 @@
 "torch.ops.aten.select.int": None,
 "torch.ops.aten.select_scatter.default": None,
 "torch.ops.aten.sigmoid.default": None,
+"torch.ops.aten.sigmoid_.default": None,
 "torch.ops.aten.sign.default": None,
 "torch.ops.aten.silu.default": None,
 "torch.ops.aten.silu_.default": None,
@@ -219,6 +222,7 @@
 "torch.ops.aten.squeeze.dim": None,
 "torch.ops.aten.squeeze.dims": None,
 "torch.ops.aten.stack.default": None,
+"torch.ops.aten.std.correction": None,
 "torch.ops.aten.sub.default": None,
 "torch.ops.aten.sub.Tensor": None,
 "torch.ops.aten.sum.default": None,
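For reference, the aten overloads newly marked as supported in the partitioner correspond to these everyday PyTorch calls (illustration only, not part of the commit):

    import torch

    x = torch.randn(1, 2, 5, 5)
    x.mul_(2.0)                                                # aten.mul_.Tensor (in-place multiply)
    torch.nn.functional.pad(x, (1, 1, 1, 1), mode="reflect")   # aten.reflection_pad2d.default
    x.sigmoid_()                                               # aten.sigmoid_.default
    torch.std(x, dim=-1, correction=0)                         # aten.std.correction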
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op/cat.cpp
@@ -107,6 +107,8 @@ OutputVector translate_stack_fx(const NodeContext& context) {
     std::deque<Output<Node>> list_elems;
     auto num_elements = context.get_input_size();
     for (size_t i = 0; i < num_elements - 1; i++) {
+        if (context.get_input(i).get_partial_shape().rank() == 1)
+            dim = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
         auto stack_input =
             context.mark_node(std::make_shared<v0::Unsqueeze>(context.get_input(static_cast<int>(i)), dim));
         list_elems.push_back(stack_input);
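The branch added above special-cases rank-1 inputs when choosing the unsqueeze axis. As a reminder of the semantics being modeled, torch.stack unsqueezes every input at the stack dimension and then concatenates (illustration only):

    import torch

    a, b = torch.randn(4), torch.randn(4)
    print(torch.stack([a, b]).shape)         # torch.Size([2, 4]): unsqueeze at dim 0, then concat
    print(torch.stack([a, b], dim=1).shape)  # torch.Size([4, 2]): unsqueeze at dim 1, then concat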
7 changes: 6 additions & 1 deletion src/frontends/pytorch/src/op/split.cpp
@@ -7,6 +7,7 @@
 //#include <climits>

 #include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/squeeze.hpp"
 #include "openvino/op/util/framework_node.hpp"
 #include "openvino/op/variadic_split.hpp"
 #include "utils.hpp"
@@ -55,7 +56,11 @@ OutputVector translate_unbind_int_fx(const NodeContext& context) {
     auto num_splits = static_cast<int>(shape[dim_val]);
     auto chunk = context.mark_node(std::make_shared<v1::Split>(input, dim, num_splits));

-    return {context.mark_node(make_list_construct(chunk->outputs()))};
+    ov::OutputVector out_vec;
+    for (auto& out : chunk->outputs())
+        out_vec.push_back(std::make_shared<v0::Squeeze>(out, dim));
+
+    return {context.mark_node(make_list_construct(out_vec))};
 }

 OutputVector translate_split_with_sizes_fx(const NodeContext& context) {
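The Squeeze added above reflects torch.unbind semantics: unlike split, each returned slice has the unbind dimension removed (illustration only):

    import torch

    x = torch.randn(3, 4)
    print([p.shape for p in torch.unbind(x, dim=0)])    # [torch.Size([4])] * 3: dim 0 removed
    print([p.shape for p in torch.split(x, 1, dim=0)])  # [torch.Size([1, 4])] * 3: dim 0 kept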
29 changes: 28 additions & 1 deletion src/frontends/pytorch/src/op/var_mean.cpp
@@ -145,6 +145,33 @@ OutputVector translate_var(const NodeContext& context) {
     return {res[0]};
 }

+OutputVector translate_std_fx(const NodeContext& context) {
+    num_inputs_check(context, 1, 2);
+    auto data = context.get_input(0);
+    ov::Output<ov::Node> axes;
+
+    if (!context.input_is_none(1)) {
+        axes = context.get_input(1);
+    }
+    int32_t correction = 0;
+    if (context.has_attribute("correction")) {
+        auto correction_node = context.get_attribute<Output<Node>>("correction");
+        auto const_node = as_type_ptr<v0::Constant>(correction_node.get_node_shared_ptr());
+        PYTORCH_OP_CONVERSION_CHECK(const_node, "correction must be const.");
+        correction = const_node->cast_vector<int32_t>()[0];
+    }
+    bool keepdim = false;
+    if (context.has_attribute("keepdim")) {
+        auto keepdim_node = context.get_attribute<Output<Node>>("keepdim");
+        auto const_node = as_type_ptr<v0::Constant>(keepdim_node.get_node_shared_ptr());
+        PYTORCH_OP_CONVERSION_CHECK(const_node, "keepdim must be const.");
+        keepdim = const_node->cast_vector<bool>()[0];
+    }
+    auto res = translate_var_mean_common(context, data, axes, correction, keepdim);
+
+    return {context.mark_node(std::make_shared<v0::Sqrt>(res[0]))};
+}
+
 OutputVector translate_std(const NodeContext& context) {
     auto res = translate_var_mean(context);
     auto var = res[0];
@@ -160,4 +187,4 @@ OutputVector translate_std_mean(const NodeContext& context) {
 } // namespace op
 } // namespace pytorch
 } // namespace frontend
-} // namespace ov
\ No newline at end of file
+} // namespace ov
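translate_std_fx above reuses the var/mean lowering and applies a final Sqrt, matching the identity std = sqrt(var) under the same Bessel correction; a quick numerical check (illustration only):

    import torch

    x = torch.randn(8)
    for correction in (0, 1):
        std = torch.std(x, correction=correction)
        ref = torch.sqrt(torch.var(x, correction=correction))
        assert torch.allclose(std, ref)  # std = sqrt(var) for a matching correction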
7 changes: 7 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
@@ -293,6 +293,7 @@ OP_CONVERTER(translate_split_with_sizes_fx);
 OP_CONVERTER(translate_stack_fx);
 OP_CONVERTER(translate_sub_fx);
 OP_CONVERTER(translate_sum_fx);
+OP_CONVERTER(translate_std_fx);
 OP_CONVERTER(translate_topk_fx);
 OP_CONVERTER(translate_to_fx);
 OP_CONVERTER(translate_transpose_fx);
@@ -798,6 +799,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
 {"aten.elu_.default", op::inplace_op<op::translate_elu>},
 {"aten.embedding.default", op::translate_embedding},
 {"aten.empty.memory_format", op::translate_empty},
+{"aten.empty_like.default", op::translate_empty_like},
 {"aten.eq.Scalar", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
 {"aten.eq.Tensor", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
 {"aten.erf.default", op::translate_erf},
@@ -849,6 +851,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
 {"aten.log1p.default", op::translate_log1p},
 {"aten.log2.default", op::translate_log2},
 {"aten.logical_and.default", op::translate_and},
+{"aten.logical_not.default", op::translate_not},
 {"aten.logsumexp.default", op::translate_logsumexp},
 {"aten.lt.Scalar", op::translate_1to1_match_2_inputs_align_types<opset10::Less>},
 {"aten.lt.Tensor", op::translate_1to1_match_2_inputs_align_types<opset10::Less>},
@@ -870,6 +873,7 @@
 {"aten.mm.default", op::translate_1to1_match_2_inputs<opset10::MatMul>},
 {"aten.mul.Scalar", op::translate_mul},
 {"aten.mul.Tensor", op::translate_mul},
+{"aten.mul_.Tensor", op::translate_mul},
 {"aten.native_batch_norm.default", op::translate_batch_norm_legit_fx},
 {"aten.native_dropout.default", op::skip_node},
 {"aten.native_group_norm.default", op::translate_group_norm_fx},
@@ -908,6 +912,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
 {"aten.select.int", op::translate_select},
 {"aten.select_scatter.default", op::translate_select_scatter_fx},
 {"aten.sigmoid.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sigmoid>},
+{"aten.sigmoid_.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sigmoid>},
 {"aten.sign.default", op::translate_sign},
 {"aten.silu.default", op::translate_1to1_match_1_inputs<opset10::Swish>},
 {"aten.silu_.default", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Swish>>},
@@ -922,6 +927,7 @@
 {"aten.squeeze.dim", op::translate_squeeze},
 {"aten.squeeze.dims", op::translate_squeeze},
 {"aten.stack.default", op::translate_stack_fx},
+{"aten.std.correction", op::translate_std_fx},
 {"aten.sub.default", op::translate_sub_fx},
 {"aten.sub.Tensor", op::translate_sub_fx},
 {"aten.sum.default", op::translate_sum_fx},
@@ -935,6 +941,7 @@
 {"aten.unbind.int", op::translate_unbind_int_fx},
 {"aten.unfold.default", op::translate_unfold},
 {"aten.unsqueeze.default", op::translate_1to1_match_2_inputs<opset10::Unsqueeze>},
+{"aten.upsample_bilinear2d.default", op::translate_upsample_bilinear2d},
 {"aten.upsample_nearest2d.default", op::translate_upsample_nearest2d},
 {"aten.var.correction", op::translate_var_fx},
 {"aten.var_mean.correction", op::translate_var_mean_fx},
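And the remaining overloads newly registered in the FX op table map to these PyTorch calls; exact aten dispatch can vary slightly by PyTorch version (illustration only):

    import torch

    x = torch.randn(1, 3, 8, 8)
    torch.empty_like(x)         # aten.empty_like.default (uninitialized, same shape/dtype)
    torch.logical_not(x > 0)    # aten.logical_not.default
    torch.nn.functional.interpolate(
        x, scale_factor=2, mode="bilinear", align_corners=False
    )                           # typically lowers to aten.upsample_bilinear2d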
