Skip to content

Commit

Permalink
param support tensor (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
luo-cheng2021 authored and zhangYiIntel committed Jul 21, 2021
1 parent 97bc0db commit 9ac2caa
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 56 deletions.
121 changes: 100 additions & 21 deletions ngraph/frontend/paddlepaddle/src/op/fill_constant_batch_size_like.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
//

#include "fill_constant_batch_size_like.hpp"
#include <limits.h>
#include <ngraph/opsets/opset6.hpp>
#include <paddlepaddle_frontend/utility.hpp>

namespace ngraph
{
Expand All @@ -13,32 +15,109 @@ namespace ngraph
{
namespace op
{
static std::shared_ptr<Node> get_val(int32_t idx, const Output<Node>& data)
{
auto startsNode = ngraph::opset6::Constant::create(element::i32, {1}, {idx});
auto endsNode = ngraph::opset6::Constant::create(element::i32, {1}, {idx + 1});
auto stridesNode = ngraph::opset6::Constant::create(element::i32, {1}, {1});
return std::make_shared<ngraph::opset6::StridedSlice>(
data,
startsNode,
endsNode,
stridesNode,
std::vector<int64_t>(1, 0),
std::vector<int64_t>(1, 0));
}

static std::shared_ptr<Node> set_val(int32_t idx,
std::shared_ptr<Node> val_node,
std::shared_ptr<Node> array_node)
{
NodeVector nodes;
if (idx > 0)
{
// [0, idx)
auto startsNode = ngraph::opset6::Constant::create(element::i32, {1}, {0});
auto endsNode = ngraph::opset6::Constant::create(element::i32, {1}, {idx});
auto stridesNode = ngraph::opset6::Constant::create(element::i32, {1}, {1});
auto head = std::make_shared<ngraph::opset6::StridedSlice>(
array_node,
startsNode,
endsNode,
stridesNode,
std::vector<int64_t>(1, 0),
std::vector<int64_t>(1, 0));
nodes.push_back(head);
}
nodes.push_back(val_node);
// [idx + 1, max)
auto startsNode =
ngraph::opset6::Constant::create(element::i32, {1}, {idx + 1});
auto endsNode = ngraph::opset6::Constant::create(element::i32, {1}, {INT_MAX});
auto stridesNode = ngraph::opset6::Constant::create(element::i32, {1}, {1});
auto tail =
std::make_shared<ngraph::opset6::StridedSlice>(array_node,
startsNode,
endsNode,
stridesNode,
std::vector<int64_t>(1, 0),
std::vector<int64_t>(1, 0));
nodes.push_back(tail);

return std::make_shared<ngraph::opset6::Concat>(nodes, 0);
}

static Output<Node> get_seed_node(const NodeContext& node)
{
auto dtype = node.get_attribute<element::Type>("dtype");
Output<Node> val_node;
auto str_value = node.get_attribute<std::string>("str_value");
switch (dtype)
{
case element::i32:
val_node =
ngraph::opset6::Constant::create(dtype, {1}, {std::stoi(str_value)});
break;
case element::i64:
val_node =
ngraph::opset6::Constant::create(dtype, {1}, {std::stoll(str_value)});
break;
case element::f32:
val_node =
ngraph::opset6::Constant::create(dtype, {1}, {std::stof(str_value)});
break;
case element::f64:
val_node =
ngraph::opset6::Constant::create(dtype, {1}, {std::stod(str_value)});
break;
default:
throw std::runtime_error(
"fill_constant_batch_size_like: dtype value is invalid");
}

return val_node;
}

NamedOutputs fill_constant_batch_size_like(const NodeContext& node)
{
// TODO to Support other data types other than FP32 #55263
auto input_dim_idx = node.get_attribute<int32_t>("input_dim_idx", 0);
auto output_dim_idx = node.get_attribute<int32_t>("output_dim_idx", 0);
auto value = node.get_attribute<float>("value");
auto input_dim_idx = node.get_attribute<int32_t>("input_dim_idx");
auto output_dim_idx = node.get_attribute<int32_t>("output_dim_idx");
auto shapes = node.get_attribute<std::vector<int32_t>>("shape");
auto input = node.get_ng_input("Input");
auto partial_shape = input.get_partial_shape();
PDPD_OP_VALIDATION_CHECK(
node,
partial_shape.is_static(),
"fill_constant_batch_size_like: must use static shape.");
auto static_shape = partial_shape.get_shape();
PDPD_OP_VALIDATION_CHECK(node,
input_dim_idx < (int32_t)static_shape.size(),
"fill_constant_batch_size_like: input_dim_idx "
"should not exceed input dims.");
PDPD_OP_VALIDATION_CHECK(node,
"fill_constant_batch_size_like: output_dim_idx "
"should not exceed shapes dims.");
shapes[output_dim_idx] = static_shape[input_dim_idx];
auto dtype = node.get_attribute<element::Type>("dtype");
auto input_shape =
std::make_shared<ngraph::opset6::ShapeOf>(input, element::i32);
// 1, cat the array:
// shape[0, shape[output_dim_idx]) + input_shape[input_dim_idx] +
// shape[shape[output_dim_idx + 1], -1]
auto input_val_node = get_val(input_dim_idx, input_shape);
auto shapes_node = ngraph::opset6::Constant::create(
ngraph::element::i32, {shapes.size()}, shapes);
auto shape_node = set_val(output_dim_idx, input_val_node, shapes_node);

// 2, use the shape broadcast the node
auto val_node = get_seed_node(node);
return node.default_single_output_mapping(
{std::make_shared<ngraph::opset6::Constant>(
dtype, Shape(shapes.begin(), shapes.end()), value)},
{std::make_shared<ngraph::opset6::Broadcast>(val_node, shape_node)},
{"Out"});
}

Expand Down
81 changes: 51 additions & 30 deletions ngraph/frontend/paddlepaddle/src/op/slice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,43 +18,64 @@ namespace ngraph
{
auto data = node.get_ng_input("Input");
auto axes = node.get_attribute<std::vector<int32_t>>("axes");
// TODO: support tensor type #55266
auto starts = node.get_attribute<std::vector<int32_t>>("starts");
// TODO: support tensor type #55266
auto ends = node.get_attribute<std::vector<int32_t>>("ends");
auto data_rank = data.get_partial_shape().rank();
size_t shape_size = data_rank.get_length();
std::vector<int32_t> fixedStarts(shape_size, 0);
std::vector<int32_t> fixedEnds(shape_size, INT_MAX);
Output<Node> start_idx_node, end_idx_node;
if (node.has_ng_input("StartsTensor"))
{
start_idx_node = node.get_ng_input("StartsTensor");
}
else if (node.has_ng_input("StartsTensorList"))
{
auto inputs = node.get_ng_inputs("StartsTensorList");
start_idx_node = std::make_shared<ngraph::opset6::Concat>(inputs, 0);
}
else
{
auto starts = node.get_attribute<std::vector<int32_t>>("starts");
start_idx_node =
opset6::Constant::create(element::i32, {starts.size()}, starts);
}

int n = 0;
for (auto i : axes)
if (node.has_ng_input("EndsTensor"))
{
PDPD_OP_VALIDATION_CHECK(node,
i < (int32_t)shape_size,
"slice: axes must be less than the X rank.");
fixedStarts[i] = starts[n];
fixedEnds[i] = ends[n];
n++;
end_idx_node = node.get_ng_input("EndsTensor");
}
else if (node.has_ng_input("EndsTensorList"))
{
auto inputs = node.get_ng_inputs("EndsTensorList");
end_idx_node = std::make_shared<ngraph::opset6::Concat>(inputs, 0);
}
else
{
auto ends = node.get_attribute<std::vector<int32_t>>("ends");
end_idx_node = opset6::Constant::create(element::i32, {ends.size()}, ends);
}

// the shape of input, such as [1, 1, 3, 3]
auto shape_node = std::make_shared<opset6::ShapeOf>(data, element::Type_t::i32);
// the input dim, such as [4]
auto shape_shape_node =
std::make_shared<opset6::ShapeOf>(shape_node, element::i32);
auto const_0_node = opset6::Constant::create(element::i32, {}, {0});
auto const_max_node = opset6::Constant::create(element::i32, {}, {INT_MAX});
// array [0:max)
auto start_node =
std::make_shared<opset6::Broadcast>(const_0_node, shape_shape_node);
auto end_node =
std::make_shared<opset6::Broadcast>(const_max_node, shape_shape_node);
auto axes_node = opset6::Constant::create(element::i32, {axes.size(), 1}, axes);
auto fixed_start_node = std::make_shared<opset6::ScatterNDUpdate>(
start_node, axes_node, start_idx_node);
auto fixed_end_node = std::make_shared<opset6::ScatterNDUpdate>(
end_node, axes_node, end_idx_node);

auto startsNode = ngraph::opset6::Constant::create(
ngraph::element::i32, {shape_size}, fixedStarts);
auto endsNode = ngraph::opset6::Constant::create(
ngraph::element::i32, {shape_size}, fixedEnds);
auto stridesNode = ngraph::opset6::Constant::create(
ngraph::element::i32, {shape_size}, std::vector<int32_t>(shape_size, 1));
return node.default_single_output_mapping(
{std::make_shared<ngraph::opset6::StridedSlice>(
data,
startsNode,
endsNode,
stridesNode,
std::vector<int64_t>(shape_size, 0),
std::vector<int64_t>(shape_size, 0))},
{std::make_shared<ngraph::opset6::StridedSlice>(data,
fixed_start_node,
fixed_end_node,
std::vector<int64_t>{},
std::vector<int64_t>{})},
{"Out"});
}

} // namespace op
} // namespace pdpd
} // namespace frontend
Expand Down
21 changes: 16 additions & 5 deletions ngraph/frontend/paddlepaddle/src/op/unsqueeze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,26 @@ namespace ngraph
{
NamedOutputs unsqueeze(const NodeContext& node)
{
// TODO to support data type other than int32_t #55168
auto data = node.get_ng_input("X");
auto axes = node.get_attribute<std::vector<int32_t>>("axes");
auto axesNode =
ngraph::opset6::Constant::create(ngraph::element::i32, {axes.size()}, axes);
Output<Node> axesNode;
if (node.has_ng_input("AxesTensor"))
{
axesNode = node.get_ng_input("AxesTensor");
}
else if (node.has_ng_input("AxesTensorList"))
{
auto inputs = node.get_ng_inputs("AxesTensorList");
axesNode = std::make_shared<ngraph::opset6::Concat>(inputs, 0);
}
else
{
auto axes = node.get_attribute<std::vector<int32_t>>("axes");
axesNode = ngraph::opset6::Constant::create(
ngraph::element::i32, {axes.size()}, axes);
}
return node.default_single_output_mapping(
{std::make_shared<ngraph::opset6::Unsqueeze>(data, axesNode)}, {"Out"});
}

} // namespace op
} // namespace pdpd
} // namespace frontend
Expand Down

0 comments on commit 9ac2caa

Please sign in to comment.