diff --git a/src/frontends/pytorch/src/op/cat.cpp b/src/frontends/pytorch/src/op/cat.cpp
index 08f828e9d64385..c0ee57854aeb9b 100644
--- a/src/frontends/pytorch/src/op/cat.cpp
+++ b/src/frontends/pytorch/src/op/cat.cpp
@@ -104,22 +104,20 @@ OutputVector translate_quantized_cat(const NodeContext& context) {
 OutputVector translate_stack_fx(const NodeContext& context) {
     num_inputs_check(context, 1, context.get_input_size());
     auto dim = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
+    int64_t axis = 0;
+
     std::deque<Output<Node>> list_elems;
     auto num_elements = context.get_input_size();
-    for (size_t i = 0; i < num_elements - 1; i++) {
-        if (context.get_input(i).get_partial_shape().rank() == 1)
-            dim = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
-        auto stack_input =
-            context.mark_node(std::make_shared<v0::Unsqueeze>(context.get_input(static_cast<int>(i)), dim));
-        list_elems.push_back(stack_input);
-    }
-    int64_t axis = 0;
+
     if (!context.get_input_type(num_elements - 1).is<type::List>()) {
-        // axis can be not present and that means that last input will have List type
         axis = context.const_input<int64_t>(num_elements - 1);
-    } else {
-        auto stack_input = context.mark_node(
-            std::make_shared<v0::Unsqueeze>(context.get_input(static_cast<int>(num_elements - 1)), dim));
+        dim = context.mark_node(v0::Constant::create(element::i32, Shape{}, {axis}));
+        num_elements -= 1;
+    }
+
+    for (size_t i = 0; i < num_elements; i++) {
+        auto stack_input =
+            context.mark_node(std::make_shared<v0::Unsqueeze>(context.get_input(static_cast<int>(i)), dim));
         list_elems.push_back(stack_input);
     }
     return translate_cat_common(context, list_elems, axis, true);
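
Note: the core of this change is that the Unsqueeze `dim` constant now follows the explicit stack axis when one is passed as the trailing scalar input; the old loop always unsqueezed along 0 (or a rank-based guess of 1) regardless of the supplied axis, and handled the last input inconsistently. The sketch below is a minimal standalone illustration of the shape arithmetic this relies on (stack == Unsqueeze each input at `axis`, then Concat on `axis`); it is plain C++ independent of the OpenVINO API, and the `unsqueeze` shape helper is hypothetical, not a frontend function.

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Hypothetical helper: insert a 1 at position `axis` of a shape,
    // mirroring what the Unsqueeze op does to each stack input.
    static std::vector<std::size_t> unsqueeze(std::vector<std::size_t> shape, std::size_t axis) {
        shape.insert(shape.begin() + static_cast<std::ptrdiff_t>(axis), 1);
        return shape;
    }

    int main() {
        // Stacking three [2, 3] tensors along axis 1: each input becomes
        // [2, 1, 3] after Unsqueeze, and Concat on axis 1 sums that dim,
        // giving [2, 3, 3]. If Unsqueeze used axis 0 instead (as the old
        // code did even when an axis was supplied), Concat on axis 1 would
        // be fed [1, 2, 3] shapes and produce the wrong result.
        const std::vector<std::vector<std::size_t>> inputs = {{2, 3}, {2, 3}, {2, 3}};
        const std::size_t axis = 1;

        std::vector<std::size_t> out = unsqueeze(inputs.front(), axis);
        out[axis] = inputs.size();  // N inputs, each contributing size 1 on `axis`.

        for (std::size_t d : out)
            std::cout << d << ' ';  // prints: 2 3 3
        std::cout << '\n';
    }

This is also why the patch rebuilds `dim` from `axis` and drops the trailing scalar from `num_elements` before the single unsqueeze loop: every remaining input is then unsqueezed at the same axis that translate_cat_common concatenates on.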