Skip to content

Commit

Permalink
[FrontEnd][PaddlePaddle] fix fill_constant_batch_size_like when attri…
Browse files Browse the repository at this point in the history
…bute str_value be empty.

This happens when export ppyolo with PaddleDetection release/2.2.
  • Loading branch information
ceciliapeng2011 committed Aug 30, 2021
1 parent 4a07a0b commit dda4fee
Showing 1 changed file with 41 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,51 @@ static std::shared_ptr<Node> set_val(int32_t idx, std::shared_ptr<Node> val_node
return std::make_shared<ngraph::opset6::Concat>(nodes, 0);
}

// "str_value" type string, default empty. "value" type float, default 0. They are value to be filled.
// if "str_value" empty, will fill with "value", else fill with "str_value".
// reference:
// https://github.com/PaddlePaddle/Paddle/blob/93d862b0adf224a0af547d1442c57fbd6d0e8efc/paddle/fluid/operators/fill_constant_batch_size_like_op.cc#L41
// https://github.com/PaddlePaddle/Paddle/blob/93d862b0adf224a0af547d1442c57fbd6d0e8efc/paddle/fluid/operators/fill_constant_batch_size_like_op.h#L44
static Output<Node> get_seed_node(const NodeContext& node) {
auto dtype = node.get_attribute<element::Type>("dtype");
Output<Node> val_node;
auto str_value = node.get_attribute<std::string>("str_value");
switch (dtype) {
case element::i32:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {std::stoi(str_value)});
break;
case element::i64:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {std::stoll(str_value)});
break;
case element::f32:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {std::stof(str_value)});
break;
case element::f64:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {std::stod(str_value)});
break;
default:
throw std::runtime_error("fill_constant_batch_size_like: dtype value is invalid");

if (str_value.empty()) {
auto float_value = node.get_attribute<float>("value");
switch (dtype) {
case element::i32:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {static_cast<int>(float_value)});
break;
case element::i64:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {static_cast<long long>(float_value)});
break;
case element::f32:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {float_value});
break;
case element::f64:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {static_cast<double>(float_value)});
break;
default:
throw std::runtime_error("fill_constant_batch_size_like: dtype value is invalid");
}
} else {
switch (dtype) {
case element::i32:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {std::stoi(str_value)});
break;
case element::i64:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {std::stoll(str_value)});
break;
case element::f32:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {std::stof(str_value)});
break;
case element::f64:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {std::stod(str_value)});
break;
default:
throw std::runtime_error("fill_constant_batch_size_like: dtype value is invalid");
}
}

return val_node;
Expand Down

0 comments on commit dda4fee

Please sign in to comment.