diff --git a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py index c4ec195511ab0c..e28f9f6171e793 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py @@ -471,7 +471,7 @@ def get_shape_for_value(self, value): print(f'[ FX DECODER DEBUG ] Decoder method called: {inspect.currentframe().f_code.co_name}') if value and ('tensor_meta' in value.meta.keys()): return PartialShape(value.meta['tensor_meta'].shape) - return PartialShape.dynamic() + return PartialShape([1]) def get_type_for_value(self, value): print(f'[ FX DECODER DEBUG ] Decoder method called: {inspect.currentframe().f_code.co_name}') @@ -480,7 +480,7 @@ def get_type_for_value(self, value): if pt_type in pt_to_ov_type_map: ov_type = pt_to_ov_type_map[pt_type] return OVAny(ov_type) - return OVAny(OVType.dynamic) + return OVAny(OVType.f32) def get_input_transpose_order(self, index): print(f'[ FX DECODER DEBUG ] Decoder method called: {inspect.currentframe().f_code.co_name}') diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/execute.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/execute.py index 31d7c43aa8e13d..d093433c701c3a 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/execute.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/execute.py @@ -73,11 +73,6 @@ def openvino_execute(gm: GraphModule, *args, executor_parameters=None, partition results1 = [res[out] for out in compiled.outputs] results = torch.from_numpy(np.array(results1, dtype=np.float32)) - flat_res, unflatten_spec = tree_flatten(results) - if (len(results) != 2): - results = torch.squeeze(results, 0) - else: - results = torch.flatten(results, end_dim=1) return results diff --git a/src/frontends/pytorch/src/op/max_poolnd.cpp b/src/frontends/pytorch/src/op/max_poolnd.cpp index 
f594b0a2b0798c..0934611f76dcf2 100644 --- a/src/frontends/pytorch/src/op/max_poolnd.cpp +++ b/src/frontends/pytorch/src/op/max_poolnd.cpp @@ -14,12 +14,20 @@ namespace op { using namespace ov::op; OutputVector translate_max_poolnd(NodeContext& context) { - num_inputs_check(context, 6, 6); + num_inputs_check(context, 4, 6); auto kernel = context.const_input<Shape>(1); auto strides = context.const_input<Strides>(2); auto pads = context.const_input<Shape>(3); // pytorch supports only symmetric paddings - auto dilations = context.const_input<Strides>(4); - auto rounding_type = context.const_input<bool>(5) ? RoundingType::CEIL : RoundingType::FLOOR; + Strides dilations; + if (!context.input_is_none(4)) { + dilations = context.const_input<Strides>(4); + } + RoundingType rounding_type; + if (context.input_is_none(5)) { + rounding_type = RoundingType::FLOOR; + } else { + rounding_type = context.const_input<bool>(5) ? RoundingType::CEIL : RoundingType::FLOOR; + } return {context.mark_node( std::make_shared<v8::MaxPool>(context.get_input(0), strides, dilations, pads, pads, kernel, rounding_type))}; @@ -28,4 +36,4 @@ OutputVector translate_max_poolnd(NodeContext& context) { } // namespace op } // namespace pytorch } // namespace frontend -} // namespace ov \ No newline at end of file +} // namespace ov diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 788d72c009664d..8c79b3162dc20d 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -251,7 +251,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() { {"aten::max", op::translate_max}, {"aten::max_pool1d", op::translate_max_poolnd}, {"aten::max_pool2d", op::translate_max_poolnd}, - {"aten.max_pool2d_with_indices.default", op::translate_adaptive_max_pool2d}, + {"aten.max_pool2d_with_indices.default", op::translate_max_poolnd}, {"aten::max_pool3d", op::translate_max_poolnd}, {"aten::mean", op::translate_mean}, {"aten.mean.dim", op::translate_mean}, diff --git a/src/frontends/pytorch/src/translate_session.cpp
b/src/frontends/pytorch/src/translate_session.cpp index f258284abba947..e4ddd31e5be922 100644 --- a/src/frontends/pytorch/src/translate_session.cpp +++ b/src/frontends/pytorch/src/translate_session.cpp @@ -124,9 +124,9 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model( // Linkage to external scope will be performed on the level of the parent operation (if or loop) // TODO: Eliminate duplication with the main code for Parameters creation PartialShape ps = node->get_input_shape(i); - auto type = simplified_type_interpret(node->get_input_type(i)); + auto type = simplified_type_interpret(node->get_input_type(i)).as<element::Type>(); // TODO: Use special API to set custom type specification - auto parameter = std::make_shared<v0::Parameter>(element::dynamic, ps); + auto parameter = std::make_shared<v0::Parameter>(type, ps); // TODO: Missing get_input_transpose_order handling for not trivial layouts tensor_map[input] = parameter; // set name of parameter to the index of node in the model