Skip to content

Commit

Permalink
Merge pull request #3 from cavusmustafa/fx_backend_shape_fix
Browse files Browse the repository at this point in the history
MaxPool update & Output shape fix (Torch FX)
  • Loading branch information
cavusmustafa authored Mar 27, 2023
2 parents b4519cf + f52d5bd commit 035dc66
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/bindings/python/src/openvino/frontend/pytorch/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def get_shape_for_value(self, value):
print(f'[ FX DECODER DEBUG ] Decoder method called: {inspect.currentframe().f_code.co_name}')
if value and ('tensor_meta' in value.meta.keys()):
return PartialShape(value.meta['tensor_meta'].shape)
return PartialShape.dynamic()
return PartialShape([1])

def get_type_for_value(self, value):
print(f'[ FX DECODER DEBUG ] Decoder method called: {inspect.currentframe().f_code.co_name}')
Expand All @@ -480,7 +480,7 @@ def get_type_for_value(self, value):
if pt_type in pt_to_ov_type_map:
ov_type = pt_to_ov_type_map[pt_type]
return OVAny(ov_type)
return OVAny(OVType.dynamic)
return OVAny(OVType.f32)

def get_input_transpose_order(self, index):
print(f'[ FX DECODER DEBUG ] Decoder method called: {inspect.currentframe().f_code.co_name}')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,6 @@ def openvino_execute(gm: GraphModule, *args, executor_parameters=None, partition

results1 = [res[out] for out in compiled.outputs]
results = torch.from_numpy(np.array(results1, dtype=np.float32))
flat_res, unflatten_spec = tree_flatten(results)
if (len(results) != 2):
results = torch.squeeze(results, 0)
else:
results = torch.flatten(results, end_dim=1)
return results


Expand Down
16 changes: 12 additions & 4 deletions src/frontends/pytorch/src/op/max_poolnd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,20 @@ namespace op {
using namespace ov::op;

OutputVector translate_max_poolnd(NodeContext& context) {
num_inputs_check(context, 6, 6);
num_inputs_check(context, 4, 6);
auto kernel = context.const_input<Shape>(1);
auto strides = context.const_input<Strides>(2);
auto pads = context.const_input<Shape>(3); // pytorch supports only symmetric paddings
auto dilations = context.const_input<Strides>(4);
auto rounding_type = context.const_input<bool>(5) ? RoundingType::CEIL : RoundingType::FLOOR;
Strides dilations;
if (!context.input_is_none(4)) {
dilations = context.const_input<Strides>(4);
}
RoundingType rounding_type;
if (context.input_is_none(5)) {
rounding_type = RoundingType::FLOOR;
} else {
rounding_type = context.const_input<bool>(5) ? RoundingType::CEIL : RoundingType::FLOOR;
}

return {context.mark_node(
std::make_shared<v8::MaxPool>(context.get_input(0), strides, dilations, pads, pads, kernel, rounding_type))};
Expand All @@ -28,4 +36,4 @@ OutputVector translate_max_poolnd(NodeContext& context) {
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
} // namespace ov
2 changes: 1 addition & 1 deletion src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ const std::map<std::string, PytorchCreatorFunction> get_supported_ops() {
{"aten::max", op::translate_max},
{"aten::max_pool1d", op::translate_max_poolnd},
{"aten::max_pool2d", op::translate_max_poolnd},
{"aten.max_pool2d_with_indices.default", op::translate_adaptive_max_pool2d},
{"aten.max_pool2d_with_indices.default", op::translate_max_poolnd},
{"aten::max_pool3d", op::translate_max_poolnd},
{"aten::mean", op::translate_mean},
{"aten.mean.dim", op::translate_mean},
Expand Down
4 changes: 2 additions & 2 deletions src/frontends/pytorch/src/translate_session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
// Linkage to external scope will be performed on the level of the parent operation (if or loop)
// TODO: Eliminate duplication with the main code for Parameters creation
PartialShape ps = node->get_input_shape(i);
auto type = simplified_type_interpret(node->get_input_type(i));
auto type = simplified_type_interpret(node->get_input_type(i)).as<element::Type>();
// TODO: Use special API to set custom type specification
auto parameter = std::make_shared<v0::Parameter>(element::dynamic, ps);
auto parameter = std::make_shared<v0::Parameter>(type, ps);
// TODO: Missing get_input_transpose_order handling for not trivial layouts
tensor_map[input] = parameter;
// set name of parameter to the index of node in the model
Expand Down

0 comments on commit 035dc66

Please sign in to comment.