Skip to content

Commit

Permalink
[FrontEnd][PaddlePaddle] fix fill_constant_batch_size_like when attri… (
Browse files Browse the repository at this point in the history
#7214)

* [FrontEnd][PaddlePaddle] fix fill_constant_batch_size_like when attribute str_value be empty.

This happens when export ppyolo with PaddleDetection release/2.2.

* code refactor.

* remove uncertain comments
  • Loading branch information
ceciliapeng2011 authored Sep 7, 2021
1 parent f99bf64 commit 8eeee5e
Showing 1 changed file with 27 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,46 @@ static std::shared_ptr<Node> set_val(int32_t idx, std::shared_ptr<Node> val_node
return std::make_shared<ngraph::opset6::Concat>(nodes, 0);
}

template <element::Type_t Type,
typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type == element::Type_t::i32 || Type == element::Type_t::i64 ||
Type == element::Type_t::f32 || Type == element::Type_t::f64,
bool>::type = true>
static Output<Node> get_seed_node(const NodeContext& node) {
auto dtype = node.get_attribute<element::Type>("dtype");
Output<Node> val_node;
auto dtype = node.get_attribute<element::Type>("dtype");
auto str_value = node.get_attribute<std::string>("str_value");
if (str_value.empty()) {
auto float_value = node.get_attribute<float>("value");
val_node = ngraph::opset6::Constant::create(dtype, {1}, {static_cast<StorageDataType>(float_value)});
} else {
std::stringstream ss(str_value);
StorageDataType tmp_value;
ss >> tmp_value;
val_node = ngraph::opset6::Constant::create(dtype, {1}, {static_cast<StorageDataType>(tmp_value)});
}
return val_node;
}

static Output<Node> get_seed_node(const NodeContext& node) {
Output<Node> val_node;
auto dtype = node.get_attribute<element::Type>("dtype");

switch (dtype) {
case element::i32:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {std::stoi(str_value)});
val_node = get_seed_node<element::i32>(node);
break;
case element::i64:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {std::stoll(str_value)});
val_node = get_seed_node<element::i64>(node);
break;
case element::f32:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {std::stof(str_value)});
val_node = get_seed_node<element::f32>(node);
break;
case element::f64:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {std::stod(str_value)});
val_node = get_seed_node<element::f64>(node);
break;
default:
throw std::runtime_error("fill_constant_batch_size_like: dtype value is invalid");
throw std::runtime_error("fill_constant_batch_size_like: unsupported dtype");
}

return val_node;
Expand Down

0 comments on commit 8eeee5e

Please sign in to comment.