From 3f9a245a152cf7f3dc83f07a2e5fd4f2e86b801b Mon Sep 17 00:00:00 2001 From: Cavus Mustafa Date: Thu, 29 Feb 2024 23:34:06 -0800 Subject: [PATCH] Stack translation fix for TorchFX --- src/frontends/pytorch/src/op/cat.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/frontends/pytorch/src/op/cat.cpp b/src/frontends/pytorch/src/op/cat.cpp index 33d2f5f18a20fd..a200ad3f997acf 100644 --- a/src/frontends/pytorch/src/op/cat.cpp +++ b/src/frontends/pytorch/src/op/cat.cpp @@ -102,7 +102,7 @@ OutputVector translate_quantized_cat(const NodeContext& context) { }; OutputVector translate_stack_fx(const NodeContext& context) { - num_inputs_check(context, 2, context.get_input_size()); + num_inputs_check(context, 1, context.get_input_size()); auto dim = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); std::deque> list_elems; auto num_elements = context.get_input_size(); @@ -112,14 +112,12 @@ OutputVector translate_stack_fx(const NodeContext& context) { list_elems.push_back(stack_input); } int64_t axis = 0; - if (context.get_input_size() > 2) - axis = context.const_input(context.get_input_size() - 1); - if (!context.get_input_type(context.get_input_size() - 1).is()) { + if (!context.get_input_type(num_elements - 1).is()) { // axis can be not present and that means that last input will have List type - axis = context.const_input(context.get_input_size() - 1); + axis = context.const_input(num_elements - 1); } else { auto stack_input = - context.mark_node(std::make_shared(context.get_input(static_cast(context.get_input_size() - 1)), dim)); + context.mark_node(std::make_shared(context.get_input(static_cast(num_elements - 1)), dim)); list_elems.push_back(stack_input); } return translate_cat_common(context, list_elems, axis, true);