From 9ac2caaf250deb3aae1a0989ea5ea86ca4e6af95 Mon Sep 17 00:00:00 2001
From: Luo Cheng
Date: Mon, 19 Jul 2021 18:10:59 +0800
Subject: [PATCH] param support tensor (#2)

---
 .../src/op/fill_constant_batch_size_like.cpp  | 121 +++++++++++++++---
 ngraph/frontend/paddlepaddle/src/op/slice.cpp |  81 +++++++-----
 .../paddlepaddle/src/op/unsqueeze.cpp         |  21 ++-
 3 files changed, 167 insertions(+), 56 deletions(-)

diff --git a/ngraph/frontend/paddlepaddle/src/op/fill_constant_batch_size_like.cpp b/ngraph/frontend/paddlepaddle/src/op/fill_constant_batch_size_like.cpp
index 09ecc31c640847..287d43e6328bfe 100644
--- a/ngraph/frontend/paddlepaddle/src/op/fill_constant_batch_size_like.cpp
+++ b/ngraph/frontend/paddlepaddle/src/op/fill_constant_batch_size_like.cpp
@@ -3,7 +3,9 @@
 //
 
 #include "fill_constant_batch_size_like.hpp"
+#include <climits>
 #include <ngraph/opsets/opset6.hpp>
+#include <string>
 
 namespace ngraph
 {
@@ -13,32 +15,109 @@ namespace ngraph
         {
             namespace op
             {
+                static std::shared_ptr<Node> get_val(int32_t idx, const Output<Node>& data)
+                {
+                    auto startsNode = ngraph::opset6::Constant::create(element::i32, {1}, {idx});
+                    auto endsNode = ngraph::opset6::Constant::create(element::i32, {1}, {idx + 1});
+                    auto stridesNode = ngraph::opset6::Constant::create(element::i32, {1}, {1});
+                    return std::make_shared<ngraph::opset6::StridedSlice>(
+                        data,
+                        startsNode,
+                        endsNode,
+                        stridesNode,
+                        std::vector<int64_t>(1, 0),
+                        std::vector<int64_t>(1, 0));
+                }
+
+                static std::shared_ptr<Node> set_val(int32_t idx,
+                                                     std::shared_ptr<Node> val_node,
+                                                     std::shared_ptr<Node> array_node)
+                {
+                    NodeVector nodes;
+                    if (idx > 0)
+                    {
+                        // [0, idx)
+                        auto startsNode = ngraph::opset6::Constant::create(element::i32, {1}, {0});
+                        auto endsNode = ngraph::opset6::Constant::create(element::i32, {1}, {idx});
+                        auto stridesNode = ngraph::opset6::Constant::create(element::i32, {1}, {1});
+                        auto head = std::make_shared<ngraph::opset6::StridedSlice>(
+                            array_node,
+                            startsNode,
+                            endsNode,
+                            stridesNode,
+                            std::vector<int64_t>(1, 0),
+                            std::vector<int64_t>(1, 0));
+                        nodes.push_back(head);
+                    }
+                    nodes.push_back(val_node);
+                    // [idx + 1, max)
+                    auto startsNode =
+                        ngraph::opset6::Constant::create(element::i32, {1}, {idx + 1});
+                    auto endsNode = ngraph::opset6::Constant::create(element::i32, {1}, {INT_MAX});
+                    auto stridesNode = ngraph::opset6::Constant::create(element::i32, {1}, {1});
+                    auto tail =
+                        std::make_shared<ngraph::opset6::StridedSlice>(array_node,
+                                                                       startsNode,
+                                                                       endsNode,
+                                                                       stridesNode,
+                                                                       std::vector<int64_t>(1, 0),
+                                                                       std::vector<int64_t>(1, 0));
+                    nodes.push_back(tail);
+
+                    return std::make_shared<ngraph::opset6::Concat>(nodes, 0);
+                }
+
+                static Output<Node> get_seed_node(const NodeContext& node)
+                {
+                    auto dtype = node.get_attribute<ngraph::element::Type>("dtype");
+                    Output<Node> val_node;
+                    auto str_value = node.get_attribute<std::string>("str_value");
+                    switch (dtype)
+                    {
+                    case element::i32:
+                        val_node =
+                            ngraph::opset6::Constant::create(dtype, {1}, {std::stoi(str_value)});
+                        break;
+                    case element::i64:
+                        val_node =
+                            ngraph::opset6::Constant::create(dtype, {1}, {std::stoll(str_value)});
+                        break;
+                    case element::f32:
+                        val_node =
+                            ngraph::opset6::Constant::create(dtype, {1}, {std::stof(str_value)});
+                        break;
+                    case element::f64:
+                        val_node =
+                            ngraph::opset6::Constant::create(dtype, {1}, {std::stod(str_value)});
+                        break;
+                    default:
+                        throw std::runtime_error(
+                            "fill_constant_batch_size_like: dtype value is invalid");
+                    }
+
+                    return val_node;
+                }
+
                 NamedOutputs fill_constant_batch_size_like(const NodeContext& node)
                 {
-                    // TODO to Support other data types other than FP32 #55263
-                    auto input_dim_idx = node.get_attribute<int32_t>("input_dim_idx", 0);
-                    auto output_dim_idx = node.get_attribute<int32_t>("output_dim_idx", 0);
-                    auto value = node.get_attribute<float>("value");
+                    auto input_dim_idx = node.get_attribute<int32_t>("input_dim_idx");
+                    auto output_dim_idx = node.get_attribute<int32_t>("output_dim_idx");
                     auto shapes = node.get_attribute<std::vector<int32_t>>("shape");
                     auto input = node.get_ng_input("Input");
-                    auto partial_shape = input.get_partial_shape();
-                    PDPD_OP_VALIDATION_CHECK(
-                        node,
-                        partial_shape.is_static(),
-                        "fill_constant_batch_size_like: must use static shape.");
-                    auto static_shape = partial_shape.get_shape();
-                    PDPD_OP_VALIDATION_CHECK(node,
-                                             input_dim_idx < (int32_t)static_shape.size(),
-                                             "fill_constant_batch_size_like: input_dim_idx "
-                                             "should not exceed input dims.");
-                    PDPD_OP_VALIDATION_CHECK(node,
-                                             output_dim_idx < (int32_t)shapes.size(),
-                                             "fill_constant_batch_size_like: output_dim_idx "
-                                             "should not exceed shapes dims.");
-                    shapes[output_dim_idx] = static_shape[input_dim_idx];
-                    auto dtype = node.get_attribute<ngraph::element::Type>("dtype");
+                    auto input_shape =
+                        std::make_shared<ngraph::opset6::ShapeOf>(input, element::i32);
+                    // 1, cat the array:
+                    //    shape[0, shape[output_dim_idx]) + input_shape[input_dim_idx] +
+                    //    shape[shape[output_dim_idx + 1], -1]
+                    auto input_val_node = get_val(input_dim_idx, input_shape);
+                    auto shapes_node = ngraph::opset6::Constant::create(
+                        ngraph::element::i32, {shapes.size()}, shapes);
+                    auto shape_node = set_val(output_dim_idx, input_val_node, shapes_node);
+
+                    // 2, use the shape to broadcast the seed value
+                    auto val_node = get_seed_node(node);
                     return node.default_single_output_mapping(
-                        {std::make_shared<ngraph::opset6::Constant>(
-                            dtype, Shape(shapes.begin(), shapes.end()), value)},
+                        {std::make_shared<ngraph::opset6::Broadcast>(val_node, shape_node)},
                         {"Out"});
                 }
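
The heart of the new fill_constant_batch_size_like lowering is the set_val helper above: it rewrites one element of a 1-D shape tensor at graph-construction time by concatenating the [0, idx) head slice, the replacement value, and the [idx + 1, end) tail slice. A minimal standalone sketch of that splice idea, assuming the ngraph opset6 headers from this repository are on the include path; the splice name and the sample values are illustrative only, not part of the patch:

#include <climits>
#include <memory>
#include <vector>
#include <ngraph/opsets/opset6.hpp>

using namespace ngraph;

// Replace element `idx` of a 1-D i32 tensor with `val` by concatenating the
// [0, idx) head, the new value, and the [idx + 1, end) tail, as set_val does.
static std::shared_ptr<Node> splice(int32_t idx,
                                    const Output<Node>& val,
                                    const Output<Node>& array)
{
    auto slice_1d = [&array](int32_t from, int32_t to) {
        auto begin = opset6::Constant::create(element::i32, {1}, {from});
        auto end = opset6::Constant::create(element::i32, {1}, {to});
        auto stride = opset6::Constant::create(element::i32, {1}, {1});
        return std::make_shared<opset6::StridedSlice>(
            array, begin, end, stride,
            std::vector<int64_t>{0}, std::vector<int64_t>{0});
    };
    OutputVector parts;
    if (idx > 0)
        parts.push_back(slice_1d(0, idx));       // head: elements [0, idx)
    parts.push_back(val);                        // the spliced-in value
    parts.push_back(slice_1d(idx + 1, INT_MAX)); // tail: elements [idx + 1, end)
    return std::make_shared<opset6::Concat>(parts, 0);
}

With array = [2, 3, 4], idx = 1 and val = [7], this builds a graph whose output is [2, 7, 4]; the converter then broadcasts the seed constant from get_seed_node to that shape.
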
diff --git a/ngraph/frontend/paddlepaddle/src/op/slice.cpp b/ngraph/frontend/paddlepaddle/src/op/slice.cpp
index 5dd14d0179b212..e1c1af63512780 100644
--- a/ngraph/frontend/paddlepaddle/src/op/slice.cpp
+++ b/ngraph/frontend/paddlepaddle/src/op/slice.cpp
@@ -18,43 +18,64 @@ namespace ngraph
                 {
                     auto data = node.get_ng_input("Input");
                     auto axes = node.get_attribute<std::vector<int32_t>>("axes");
-                    // TODO: support tensor type #55266
-                    auto starts = node.get_attribute<std::vector<int32_t>>("starts");
-                    // TODO: support tensor type #55266
-                    auto ends = node.get_attribute<std::vector<int32_t>>("ends");
-                    auto data_rank = data.get_partial_shape().rank();
-                    size_t shape_size = data_rank.get_length();
-                    std::vector<int32_t> fixedStarts(shape_size, 0);
-                    std::vector<int32_t> fixedEnds(shape_size, INT_MAX);
+                    Output<Node> start_idx_node, end_idx_node;
+                    if (node.has_ng_input("StartsTensor"))
+                    {
+                        start_idx_node = node.get_ng_input("StartsTensor");
+                    }
+                    else if (node.has_ng_input("StartsTensorList"))
+                    {
+                        auto inputs = node.get_ng_inputs("StartsTensorList");
+                        start_idx_node = std::make_shared<opset6::Concat>(inputs, 0);
+                    }
+                    else
+                    {
+                        auto starts = node.get_attribute<std::vector<int32_t>>("starts");
+                        start_idx_node =
+                            opset6::Constant::create(element::i32, {starts.size()}, starts);
+                    }
 
-                    int n = 0;
-                    for (auto i : axes)
+                    if (node.has_ng_input("EndsTensor"))
                     {
-                        PDPD_OP_VALIDATION_CHECK(node,
-                                                 i < (int32_t)shape_size,
-                                                 "slice: axes must be less than the X rank.");
-                        fixedStarts[i] = starts[n];
-                        fixedEnds[i] = ends[n];
-                        n++;
+                        end_idx_node = node.get_ng_input("EndsTensor");
                     }
+                    else if (node.has_ng_input("EndsTensorList"))
+                    {
+                        auto inputs = node.get_ng_inputs("EndsTensorList");
+                        end_idx_node = std::make_shared<opset6::Concat>(inputs, 0);
+                    }
+                    else
+                    {
+                        auto ends = node.get_attribute<std::vector<int32_t>>("ends");
+                        end_idx_node = opset6::Constant::create(element::i32, {ends.size()}, ends);
+                    }
+
+                    // the shape of input, such as [1, 1, 3, 3]
+                    auto shape_node = std::make_shared<opset6::ShapeOf>(data, element::Type_t::i32);
+                    // the input dim, such as [4]
+                    auto shape_shape_node =
+                        std::make_shared<opset6::ShapeOf>(shape_node, element::i32);
+                    auto const_0_node = opset6::Constant::create(element::i32, {}, {0});
+                    auto const_max_node = opset6::Constant::create(element::i32, {}, {INT_MAX});
+                    // array [0:max)
+                    auto start_node =
+                        std::make_shared<opset6::Broadcast>(const_0_node, shape_shape_node);
+                    auto end_node =
+                        std::make_shared<opset6::Broadcast>(const_max_node, shape_shape_node);
+                    auto axes_node = opset6::Constant::create(element::i32, {axes.size(), 1}, axes);
+                    auto fixed_start_node = std::make_shared<opset6::ScatterNDUpdate>(
+                        start_node, axes_node, start_idx_node);
+                    auto fixed_end_node = std::make_shared<opset6::ScatterNDUpdate>(
+                        end_node, axes_node, end_idx_node);
 
-                    auto startsNode = ngraph::opset6::Constant::create(
-                        ngraph::element::i32, {shape_size}, fixedStarts);
-                    auto endsNode = ngraph::opset6::Constant::create(
-                        ngraph::element::i32, {shape_size}, fixedEnds);
-                    auto stridesNode = ngraph::opset6::Constant::create(
-                        ngraph::element::i32, {shape_size}, std::vector<int32_t>(shape_size, 1));
                     return node.default_single_output_mapping(
-                        {std::make_shared<ngraph::opset6::StridedSlice>(
-                            data,
-                            startsNode,
-                            endsNode,
-                            stridesNode,
-                            std::vector<int64_t>(shape_size, 0),
-                            std::vector<int64_t>(shape_size, 0))},
+                        {std::make_shared<opset6::StridedSlice>(data,
+                                                                fixed_start_node,
+                                                                fixed_end_node,
+                                                                std::vector<int64_t>{},
+                                                                std::vector<int64_t>{})},
                         {"Out"});
                 }
-
             } // namespace op
         } // namespace pdpd
     } // namespace frontend
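
The reworked slice converter drops the static-rank requirement: default start/end vectors (all 0 and all INT_MAX) are broadcast to the runtime rank of the input, then the user-supplied bounds are scattered into the requested axes with ScatterNDUpdate. A rough self-contained sketch of that pattern, again assuming the ngraph opset6 headers; full_rank_bounds is an illustrative name and not part of the patch:

#include <climits>
#include <memory>
#include <utility>
#include <vector>
#include <ngraph/opsets/opset6.hpp>

using namespace ngraph;

// Expand per-axis starts/ends (given only for `axes`) into full-rank vectors:
// every axis defaults to [0, INT_MAX), then the provided bounds are scattered
// into the positions listed in `axes`, mirroring the slice converter above.
static std::pair<Output<Node>, Output<Node>>
full_rank_bounds(const Output<Node>& data,
                 const std::vector<int32_t>& axes,
                 const Output<Node>& starts,
                 const Output<Node>& ends)
{
    // shape of `data` (e.g. [1, 1, 3, 3]) and its rank as a 1-element tensor (e.g. [4])
    auto shape = std::make_shared<opset6::ShapeOf>(data, element::i32);
    auto rank = std::make_shared<opset6::ShapeOf>(shape, element::i32);
    auto zero = opset6::Constant::create(element::i32, {}, {0});
    auto int_max = opset6::Constant::create(element::i32, {}, {INT_MAX});
    // default begin = [0, 0, ...], default end = [INT_MAX, INT_MAX, ...]
    auto begin = std::make_shared<opset6::Broadcast>(zero, rank);
    auto end = std::make_shared<opset6::Broadcast>(int_max, rank);
    // ScatterNDUpdate on a 1-D tensor expects indices of shape [n, 1]
    auto idx = opset6::Constant::create(element::i32, {axes.size(), 1}, axes);
    return {std::make_shared<opset6::ScatterNDUpdate>(begin, idx, starts),
            std::make_shared<opset6::ScatterNDUpdate>(end, idx, ends)};
}

For a rank-4 input with axes = {2}, starts = [1] and ends = [3], the pair evaluates to [0, 0, 1, 0] and [INT_MAX, INT_MAX, 3, INT_MAX], which feed the begin/end inputs of the final StridedSlice.
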
diff --git a/ngraph/frontend/paddlepaddle/src/op/unsqueeze.cpp b/ngraph/frontend/paddlepaddle/src/op/unsqueeze.cpp
index dc9b7e0bb9c2e9..fad5e489faea99 100644
--- a/ngraph/frontend/paddlepaddle/src/op/unsqueeze.cpp
+++ b/ngraph/frontend/paddlepaddle/src/op/unsqueeze.cpp
@@ -15,15 +15,26 @@ namespace ngraph
             {
                 NamedOutputs unsqueeze(const NodeContext& node)
                 {
-                    // TODO to support data type other than int32_t #55168
                     auto data = node.get_ng_input("X");
-                    auto axes = node.get_attribute<std::vector<int32_t>>("axes");
-                    auto axesNode =
-                        ngraph::opset6::Constant::create(ngraph::element::i32, {axes.size()}, axes);
+                    Output<Node> axesNode;
+                    if (node.has_ng_input("AxesTensor"))
+                    {
+                        axesNode = node.get_ng_input("AxesTensor");
+                    }
+                    else if (node.has_ng_input("AxesTensorList"))
+                    {
+                        auto inputs = node.get_ng_inputs("AxesTensorList");
+                        axesNode = std::make_shared<ngraph::opset6::Concat>(inputs, 0);
+                    }
+                    else
+                    {
+                        auto axes = node.get_attribute<std::vector<int32_t>>("axes");
+                        axesNode = ngraph::opset6::Constant::create(
+                            ngraph::element::i32, {axes.size()}, axes);
+                    }
                     return node.default_single_output_mapping(
                         {std::make_shared<ngraph::opset6::Unsqueeze>(data, axesNode)}, {"Out"});
                 }
-
             } // namespace op
         } // namespace pdpd
     } // namespace frontend
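
The unsqueeze converter now accepts its axes from the AxesTensor input, the AxesTensorList inputs, or the legacy axes attribute, but the node it emits is unchanged. A tiny illustrative example of what that node computes, assuming the ngraph opset6 headers; unsqueeze_example and its values are made up for illustration:

#include <memory>
#include <ngraph/opsets/opset6.hpp>

using namespace ngraph;

// Unsqueezing axes 0 and 2 turns a [3, 4] input into a [1, 3, 1, 4] output,
// regardless of which of the three axes sources the converter picked.
static std::shared_ptr<Node> unsqueeze_example()
{
    auto x = std::make_shared<opset6::Parameter>(element::f32, Shape{3, 4});
    auto axes = opset6::Constant::create(element::i64, {2}, {0, 2});
    return std::make_shared<opset6::Unsqueeze>(x, axes);
}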