diff --git a/ngraph/frontend/paddlepaddle/src/op/slice.cpp b/ngraph/frontend/paddlepaddle/src/op/slice.cpp index 5dd14d0179b212..e1c1af63512780 100644 --- a/ngraph/frontend/paddlepaddle/src/op/slice.cpp +++ b/ngraph/frontend/paddlepaddle/src/op/slice.cpp @@ -18,43 +18,64 @@ namespace ngraph { auto data = node.get_ng_input("Input"); auto axes = node.get_attribute>("axes"); - // TODO: support tensor type #55266 - auto starts = node.get_attribute>("starts"); - // TODO: support tensor type #55266 - auto ends = node.get_attribute>("ends"); - auto data_rank = data.get_partial_shape().rank(); - size_t shape_size = data_rank.get_length(); - std::vector fixedStarts(shape_size, 0); - std::vector fixedEnds(shape_size, INT_MAX); + Output start_idx_node, end_idx_node; + if (node.has_ng_input("StartsTensor")) + { + start_idx_node = node.get_ng_input("StartsTensor"); + } + else if (node.has_ng_input("StartsTensorList")) + { + auto inputs = node.get_ng_inputs("StartsTensorList"); + start_idx_node = std::make_shared(inputs, 0); + } + else + { + auto starts = node.get_attribute>("starts"); + start_idx_node = + opset6::Constant::create(element::i32, {starts.size()}, starts); + } - int n = 0; - for (auto i : axes) + if (node.has_ng_input("EndsTensor")) { - PDPD_OP_VALIDATION_CHECK(node, - i < (int32_t)shape_size, - "slice: axes must be less than the X rank."); - fixedStarts[i] = starts[n]; - fixedEnds[i] = ends[n]; - n++; + end_idx_node = node.get_ng_input("EndsTensor"); } + else if (node.has_ng_input("EndsTensorList")) + { + auto inputs = node.get_ng_inputs("EndsTensorList"); + end_idx_node = std::make_shared(inputs, 0); + } + else + { + auto ends = node.get_attribute>("ends"); + end_idx_node = opset6::Constant::create(element::i32, {ends.size()}, ends); + } + + // the shape of input, such as [1, 1, 3, 3] + auto shape_node = std::make_shared(data, element::Type_t::i32); + // the input dim, such as [4] + auto shape_shape_node = + std::make_shared(shape_node, element::i32); + auto const_0_node = opset6::Constant::create(element::i32, {}, {0}); + auto const_max_node = opset6::Constant::create(element::i32, {}, {INT_MAX}); + // array [0:max) + auto start_node = + std::make_shared(const_0_node, shape_shape_node); + auto end_node = + std::make_shared(const_max_node, shape_shape_node); + auto axes_node = opset6::Constant::create(element::i32, {axes.size(), 1}, axes); + auto fixed_start_node = std::make_shared( + start_node, axes_node, start_idx_node); + auto fixed_end_node = std::make_shared( + end_node, axes_node, end_idx_node); - auto startsNode = ngraph::opset6::Constant::create( - ngraph::element::i32, {shape_size}, fixedStarts); - auto endsNode = ngraph::opset6::Constant::create( - ngraph::element::i32, {shape_size}, fixedEnds); - auto stridesNode = ngraph::opset6::Constant::create( - ngraph::element::i32, {shape_size}, std::vector(shape_size, 1)); return node.default_single_output_mapping( - {std::make_shared( - data, - startsNode, - endsNode, - stridesNode, - std::vector(shape_size, 0), - std::vector(shape_size, 0))}, + {std::make_shared(data, + fixed_start_node, + fixed_end_node, + std::vector{}, + std::vector{})}, {"Out"}); } - } // namespace op } // namespace pdpd } // namespace frontend diff --git a/ngraph/frontend/paddlepaddle/src/op/unsqueeze.cpp b/ngraph/frontend/paddlepaddle/src/op/unsqueeze.cpp index dc9b7e0bb9c2e9..fad5e489faea99 100644 --- a/ngraph/frontend/paddlepaddle/src/op/unsqueeze.cpp +++ b/ngraph/frontend/paddlepaddle/src/op/unsqueeze.cpp @@ -15,15 +15,26 @@ namespace ngraph { NamedOutputs unsqueeze(const NodeContext& node) { - // TODO to support data type other than int32_t #55168 auto data = node.get_ng_input("X"); - auto axes = node.get_attribute>("axes"); - auto axesNode = - ngraph::opset6::Constant::create(ngraph::element::i32, {axes.size()}, axes); + Output axesNode; + if (node.has_ng_input("AxesTensor")) + { + axesNode = node.get_ng_input("AxesTensor"); + } + else if (node.has_ng_input("AxesTensorList")) + { + auto inputs = node.get_ng_inputs("AxesTensorList"); + axesNode = std::make_shared(inputs, 0); + } + else + { + auto axes = node.get_attribute>("axes"); + axesNode = ngraph::opset6::Constant::create( + ngraph::element::i32, {axes.size()}, axes); + } return node.default_single_output_mapping( {std::make_shared(data, axesNode)}, {"Out"}); } - } // namespace op } // namespace pdpd } // namespace frontend