Skip to content

Commit

Permalink
param support tensor (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
luo-cheng2021 authored Jul 19, 2021
1 parent a15b5dd commit f8b6b4c
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 35 deletions.
81 changes: 51 additions & 30 deletions ngraph/frontend/paddlepaddle/src/op/slice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,43 +18,64 @@ namespace ngraph
{
auto data = node.get_ng_input("Input");
auto axes = node.get_attribute<std::vector<int32_t>>("axes");
// TODO: support tensor type #55266
auto starts = node.get_attribute<std::vector<int32_t>>("starts");
// TODO: support tensor type #55266
auto ends = node.get_attribute<std::vector<int32_t>>("ends");
auto data_rank = data.get_partial_shape().rank();
size_t shape_size = data_rank.get_length();
std::vector<int32_t> fixedStarts(shape_size, 0);
std::vector<int32_t> fixedEnds(shape_size, INT_MAX);
Output<Node> start_idx_node, end_idx_node;
if (node.has_ng_input("StartsTensor"))
{
start_idx_node = node.get_ng_input("StartsTensor");
}
else if (node.has_ng_input("StartsTensorList"))
{
auto inputs = node.get_ng_inputs("StartsTensorList");
start_idx_node = std::make_shared<ngraph::opset6::Concat>(inputs, 0);
}
else
{
auto starts = node.get_attribute<std::vector<int32_t>>("starts");
start_idx_node =
opset6::Constant::create(element::i32, {starts.size()}, starts);
}

int n = 0;
for (auto i : axes)
if (node.has_ng_input("EndsTensor"))
{
PDPD_OP_VALIDATION_CHECK(node,
i < (int32_t)shape_size,
"slice: axes must be less than the X rank.");
fixedStarts[i] = starts[n];
fixedEnds[i] = ends[n];
n++;
end_idx_node = node.get_ng_input("EndsTensor");
}
else if (node.has_ng_input("EndsTensorList"))
{
auto inputs = node.get_ng_inputs("EndsTensorList");
end_idx_node = std::make_shared<ngraph::opset6::Concat>(inputs, 0);
}
else
{
auto ends = node.get_attribute<std::vector<int32_t>>("ends");
end_idx_node = opset6::Constant::create(element::i32, {ends.size()}, ends);
}

// the shape of input, such as [1, 1, 3, 3]
auto shape_node = std::make_shared<opset6::ShapeOf>(data, element::Type_t::i32);
// the input dim, such as [4]
auto shape_shape_node =
std::make_shared<opset6::ShapeOf>(shape_node, element::i32);
auto const_0_node = opset6::Constant::create(element::i32, {}, {0});
auto const_max_node = opset6::Constant::create(element::i32, {}, {INT_MAX});
// array [0:max)
auto start_node =
std::make_shared<opset6::Broadcast>(const_0_node, shape_shape_node);
auto end_node =
std::make_shared<opset6::Broadcast>(const_max_node, shape_shape_node);
auto axes_node = opset6::Constant::create(element::i32, {axes.size(), 1}, axes);
auto fixed_start_node = std::make_shared<opset6::ScatterNDUpdate>(
start_node, axes_node, start_idx_node);
auto fixed_end_node = std::make_shared<opset6::ScatterNDUpdate>(
end_node, axes_node, end_idx_node);

auto startsNode = ngraph::opset6::Constant::create(
ngraph::element::i32, {shape_size}, fixedStarts);
auto endsNode = ngraph::opset6::Constant::create(
ngraph::element::i32, {shape_size}, fixedEnds);
auto stridesNode = ngraph::opset6::Constant::create(
ngraph::element::i32, {shape_size}, std::vector<int32_t>(shape_size, 1));
return node.default_single_output_mapping(
{std::make_shared<ngraph::opset6::StridedSlice>(
data,
startsNode,
endsNode,
stridesNode,
std::vector<int64_t>(shape_size, 0),
std::vector<int64_t>(shape_size, 0))},
{std::make_shared<ngraph::opset6::StridedSlice>(data,
fixed_start_node,
fixed_end_node,
std::vector<int64_t>{},
std::vector<int64_t>{})},
{"Out"});
}

} // namespace op
} // namespace pdpd
} // namespace frontend
Expand Down
21 changes: 16 additions & 5 deletions ngraph/frontend/paddlepaddle/src/op/unsqueeze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,26 @@ namespace ngraph
{
NamedOutputs unsqueeze(const NodeContext& node)
{
// TODO to support data type other than int32_t #55168
auto data = node.get_ng_input("X");
auto axes = node.get_attribute<std::vector<int32_t>>("axes");
auto axesNode =
ngraph::opset6::Constant::create(ngraph::element::i32, {axes.size()}, axes);
Output<Node> axesNode;
if (node.has_ng_input("AxesTensor"))
{
axesNode = node.get_ng_input("AxesTensor");
}
else if (node.has_ng_input("AxesTensorList"))
{
auto inputs = node.get_ng_inputs("AxesTensorList");
axesNode = std::make_shared<ngraph::opset6::Concat>(inputs, 0);
}
else
{
auto axes = node.get_attribute<std::vector<int32_t>>("axes");
axesNode = ngraph::opset6::Constant::create(
ngraph::element::i32, {axes.size()}, axes);
}
return node.default_single_output_mapping(
{std::make_shared<ngraph::opset6::Unsqueeze>(data, axesNode)}, {"Out"});
}

} // namespace op
} // namespace pdpd
} // namespace frontend
Expand Down

0 comments on commit f8b6b4c

Please sign in to comment.