Commit 89281ad

Merge pull request #15 from mvafin/mvafin/pt_fe/flatten

Fix flatten and types bug

slyalin authored Oct 5, 2022
2 parents 63b19eb + d4e2e42 commit 89281ad

Showing 5 changed files with 79 additions and 47 deletions.
@@ -28,7 +28,7 @@

 pt_to_ov_type_map = {
     'float': OVType.f32,
-    'int': OVType.i64,
+    'int': OVType.i32,
     'torch.float32': OVType.f32,
     'torch.int32': OVType.i32
 }
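For reference, this map is how the Python decoder turns TorchScript type strings into OpenVINO element types; the fix maps plain 'int' to i32 instead of i64. A minimal sketch of a lookup helper (the helper and its error handling are assumptions, not part of this diff):

    def to_ov_type(pt_type_str):
        # Raise on TorchScript types the map does not cover (assumed handling).
        if pt_type_str not in pt_to_ov_type_map:
            raise RuntimeError('Unsupported TorchScript type: ' + pt_type_str)
        return pt_to_ov_type_map[pt_type_str]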
@@ -52,6 +52,9 @@ def __init__ (self, pt_module):
         #print(pt_module)
         #exit()

+    def free_decoders(self):
+        decoders.clear()
+
     def inputs (self):
         return [x.unique() for x in self.pt_module.inputs()]

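The new free_decoders() clears what is presumably a module-level decoders list that keeps Python decoder objects alive while the C++ frontend holds references to them. A minimal sketch of that keep-alive pattern, assumed from context (the decoders list and the registration in __init__ sit outside this hunk):

    decoders = []  # module-level registry; keeps decoder objects alive (assumption)

    class TorchScriptPythonDecoder:  # class name assumed from the frontend's naming
        def __init__(self, pt_module):
            self.pt_module = pt_module
            decoders.append(self)  # registered here, released via free_decoders()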
5 changes: 3 additions & 2 deletions src/frontends/pytorch/src/frontend.cpp
@@ -64,8 +64,7 @@ std::shared_ptr<Model> FrontEnd::convert_partially(const ov::frontend::InputMode
             } else {
                 std::cout << "[ WARNING ] Couldn't remove parameter[0] in converted Pytorch model\n";
             }
         }
-        apply_pytorch_conversion_transforms(model);
     }
     return model;
 } catch (const std::runtime_error& e) {
     std::cerr << "[ ERROR ] Unexpected error while converting pytorch model: " << e.what() << "\n";
@@ -85,6 +84,8 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
     manager.register_pass<ov::frontend::pytorch::pass::PrimListUnpackReplacer>();

     manager.run_passes(model);
+
+    apply_pytorch_conversion_transforms(model);
 }

 void FrontEnd::add_extension(const std::shared_ptr<ov::Extension>& extension) {
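Net effect of these two hunks: apply_pytorch_conversion_transforms(model) now runs at the end of normalize(), after manager.run_passes(model), rather than inside convert_partially(). A hedged sketch of the usual conversion flow from Python, where normalize() is triggered (FrontEndManager is the standard OpenVINO API; the 'pytorch' framework name and the decoder argument are assumptions for this PoC):

    from openvino.frontend import FrontEndManager

    fem = FrontEndManager()
    fe = fem.load_by_framework('pytorch')  # framework name assumed for this PoC
    input_model = fe.load(decoder)         # decoder: a TorchScriptPythonDecoder-like object
    model = fe.convert(input_model)        # convert() typically runs normalize() internally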
64 changes: 64 additions & 0 deletions src/frontends/pytorch/src/op/flatten.cpp
@@ -0,0 +1,64 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset8.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_flatten(NodeContext& context) {
    auto start_dim = context.const_input<int64_t>(1);
    auto end_dim = context.const_input<int64_t>(2);

    auto shape = std::make_shared<opset8::ShapeOf>(context.get_input(0), element::i32);
    auto rank_ = std::make_shared<opset8::ShapeOf>(shape, element::i32);
    auto rank = std::make_shared<opset8::Squeeze>(rank_);
    // TODO: use opset::If for dim normalization; for now negative dims require constant inputs
    auto start_dim_node = context.get_input(1);
    auto end_dim_node = context.get_input(2);
    if (start_dim < 0) {
        start_dim_node = std::make_shared<opset8::Add>(rank, start_dim_node);
    }
    if (end_dim < 0) {
        end_dim_node = std::make_shared<opset8::Add>(rank, end_dim_node);
    }
    auto delta = std::make_shared<opset8::Subtract>(end_dim_node, start_dim_node);
    auto rank_delta = std::make_shared<opset8::Subtract>(rank, delta);
    // Build a vector of (rank - delta) zeros: the Loop body yields a single zero per
    // iteration, and the concatenated slices form the reshape pattern.
    auto true_const0 = opset8::Constant::create(element::boolean, Shape{}, {1});
    auto zeros_loop = std::make_shared<opset8::Loop>(rank_delta, true_const0);
    auto true_const = opset8::Constant::create(element::boolean, Shape{}, {1});
    auto result_true = std::make_shared<opset8::Result>(true_const);
    auto zero_const = opset8::Constant::create(element::i32, Shape{1}, {0});
    auto result_zero = std::make_shared<opset8::Result>(zero_const);
    auto f = std::make_shared<ov::Model>(ResultVector{result_true, result_zero}, ParameterVector{});
    zeros_loop->set_function(f);
    zeros_loop->set_special_body_ports({-1, 0});
    auto zeros = zeros_loop->get_concatenated_slices(result_zero, 0, 1, 1, -1, 0);
    // Write -1 at start_dim; Reshape with special_zero=true keeps the input dimension
    // wherever the pattern holds a 0 and infers the flattened extent at the -1.
    auto neg_1_const = opset8::Constant::create(element::i32, Shape{1}, {-1});
    auto axis_0 = opset8::Constant::create(element::i32, Shape{1}, {0});
    auto start_dim_node_ = std::make_shared<opset8::Unsqueeze>(start_dim_node, axis_0);
    auto new_shape = std::make_shared<opset8::ScatterElementsUpdate>(zeros, start_dim_node_, neg_1_const, axis_0);

    context.mark_nodes({shape,
                        rank_,
                        rank,
                        delta,
                        rank_delta,
                        true_const0,
                        zeros_loop,
                        neg_1_const,
                        axis_0,
                        start_dim_node_,
                        new_shape});

    return {context.mark_node(std::make_shared<opset8::Reshape>(context.get_input(0), new_shape, true))};
};

}  // namespace op
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov
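To make the translator above easier to follow: the Loop emits one zero per iteration and get_concatenated_slices stitches them into a vector of rank - delta zeros; ScatterElementsUpdate then writes -1 at start_dim. A minimal Python sketch of the shape pattern being computed (helper name hypothetical):

    def flatten_target_shape(rank, start_dim, end_dim):
        # Negative dims are normalized by the Add-with-rank nodes in the graph.
        if start_dim < 0:
            start_dim += rank
        if end_dim < 0:
            end_dim += rank
        delta = end_dim - start_dim
        new_shape = [0] * (rank - delta)  # the zeros produced by the Loop subgraph
        new_shape[start_dim] = -1         # the ScatterElementsUpdate write
        return new_shape

    # flatten(x, 1, -1) on a rank-4 input yields the pattern [0, -1]:
    # Reshape keeps dim 0 and flattens the rest, e.g. [2, 3, 4, 5] -> [2, 60].
    assert flatten_target_shape(4, 1, -1) == [0, -1]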
29 changes: 6 additions & 23 deletions src/frontends/pytorch/src/op/transpose.cpp
@@ -2,11 +2,8 @@
 // SPDX-License-Identifier: Apache-2.0
 //

-#include <climits>
-
 #include "openvino/frontend/pytorch/node_context.hpp"
 #include "openvino/opsets/opset8.hpp"
-#include "utils.hpp"

 namespace ov {
 namespace frontend {
@@ -33,28 +30,14 @@ OutputVector translate_transpose(NodeContext& context) {
     auto range = std::make_shared<opset8::Range>(start, rank, step, element::i32);

     auto axis_0 = opset8::Constant::create(element::i64, Shape{}, {0});
-    dim0_node = std::make_shared<opset8::Unsqueeze>(dim0_node, axis_0);
-    dim1_node = std::make_shared<opset8::Unsqueeze>(dim1_node, axis_0);
-    auto indices = std::make_shared<opset8::Concat>(OutputVector{dim0_node, dim1_node}, 0);
-    auto updates = std::make_shared<opset8::Concat>(OutputVector{dim1_node, dim0_node}, 0);
+    auto dim0_node_ = std::make_shared<opset8::Unsqueeze>(dim0_node, axis_0);
+    auto dim1_node_ = std::make_shared<opset8::Unsqueeze>(dim1_node, axis_0);
+    auto indices = std::make_shared<opset8::Concat>(OutputVector{dim0_node_, dim1_node_}, 0);
+    auto updates = std::make_shared<opset8::Concat>(OutputVector{dim1_node_, dim0_node_}, 0);
     auto scatter = std::make_shared<opset8::ScatterElementsUpdate>(range, indices, updates, axis_0);
+    context.mark_nodes(
+        {shape, rank_, rank, start, step, range, axis_0, dim0_node_, dim1_node_, indices, updates, scatter});

-    /*auto data_pshape = context.get_input(0).get_partial_shape();
-    auto rank = data_pshape.rank();
-    OV_FRONTEND_REQUIRE(rank.is_static());
-    auto _rank = rank.get_length();
-    if (dim0 < 0) {
-        dim0 = _rank + dim0;
-    }
-    if (dim1 < 0) {
-        dim1 = _rank + dim1;
-    }
-    OV_FRONTEND_REQUIRE(dim0 > 0 && dim1 > 0);
-    OV_FRONTEND_REQUIRE(dim0 < _rank && dim1 < _rank);
-    std::vector<int64_t> order(_rank, 0);
-    std::iota(order.begin(), order.end(), 0);
-    std::swap(order[dim0], order[dim1]);
-    auto order_const = context.mark_node(opset8::Constant::create(element::i64, {order.size()}, order));*/
     return {context.mark_node(std::make_shared<opset8::Transpose>(context.get_input(0), scatter))};
 };

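Compared with the removed block, the permutation is now computed in-graph, so a static rank is no longer required: Range builds the identity order and ScatterElementsUpdate swaps the dim0 and dim1 entries. A minimal Python sketch of the order being produced (hypothetical helper; negative dims are normalized before this point and omitted here):

    def transpose_order(rank, dim0, dim1):
        order = list(range(rank))  # Range(start=0, stop=rank, step=1)
        # ScatterElementsUpdate with indices [dim0, dim1] and updates [dim1, dim0]
        order[dim0], order[dim1] = order[dim1], order[dim0]
        return order

    assert transpose_order(4, 0, 2) == [2, 1, 0, 3]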
23 changes: 2 additions & 21 deletions src/frontends/pytorch/src/op_table.cpp
@@ -15,6 +15,7 @@ namespace op {

 #define OP_CONVERTER(op) OutputVector op(NodeContext& node)

+OP_CONVERTER(translate_flatten);
 OP_CONVERTER(translate_if);
 OP_CONVERTER(translate_loop);
 OP_CONVERTER(translate_slice);
@@ -415,27 +416,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
              std::make_shared<opset8::ReduceMean>(context.get_input(0), context.get_input(1), keep_dims))};
          }},

-        {"aten::flatten",
-         [](NodeContext& context) -> OutputVector {
-             auto start_dim = context.const_input<int64_t>(1);
-             auto end_dim = context.const_input<int64_t>(2);
-             auto data_pshape = context.get_input(0).get_partial_shape();
-             OV_FRONTEND_REQUIRE(data_pshape.rank().is_static());  // TODO: support dynamic rank
-             auto rank = data_pshape.rank().get_length();
-             if (start_dim < 0) {
-                 start_dim = rank + start_dim;
-             }
-             if (end_dim < 0) {
-                 end_dim = rank + end_dim;
-             }
-             OV_FRONTEND_REQUIRE(start_dim < end_dim);
-             auto delta = end_dim - start_dim;
-             std::vector<int64_t> new_shape(rank - delta, 0);
-             new_shape[start_dim] = -1;
-             auto new_shape_const =
-                 context.mark_node(opset8::Constant::create(element::i64, {new_shape.size()}, new_shape));
-             return {context.mark_node(std::make_shared<opset8::Reshape>(context.get_input(0), new_shape_const, true))};
-         }},
+        {"aten::flatten", op::translate_flatten},

         {"prim::NumToTensor",
          [](NodeContext& context) -> OutputVector {
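With the inline lambda replaced by op::translate_flatten (declared through OP_CONVERTER above), aten::flatten follows the same registration pattern as the other named converters. A hedged Python sketch of the dispatch this table implements, with stand-in names (the real table maps TorchScript op names to C++ CreatorFunction callables):

    def translate_flatten(node):  # stand-in for op::translate_flatten
        return ['flatten_output']

    SUPPORTED_OPS = {'aten::flatten': translate_flatten}

    def convert_node(kind, node):
        converter = SUPPORTED_OPS.get(kind)
        if converter is None:
            raise RuntimeError('No converter registered for ' + kind)
        return converter(node)

    assert convert_node('aten::flatten', None) == ['flatten_output']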
