Skip to content

Commit

Permalink
code refactor.
Browse files Browse the repository at this point in the history
  • Loading branch information
ceciliapeng2011 committed Aug 30, 2021
1 parent dda4fee commit 75445d3
Showing 1 changed file with 34 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,46 +60,46 @@ static std::shared_ptr<Node> set_val(int32_t idx, std::shared_ptr<Node> val_node
// reference:
// https://github.com/PaddlePaddle/Paddle/blob/93d862b0adf224a0af547d1442c57fbd6d0e8efc/paddle/fluid/operators/fill_constant_batch_size_like_op.cc#L41
// https://github.com/PaddlePaddle/Paddle/blob/93d862b0adf224a0af547d1442c57fbd6d0e8efc/paddle/fluid/operators/fill_constant_batch_size_like_op.h#L44
template <element::Type_t Type,
typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type == element::Type_t::i32 || Type == element::Type_t::i64 ||
Type == element::Type_t::f32 || Type == element::Type_t::f64,
bool>::type = true>
static Output<Node> get_seed_node(const NodeContext& node) {
auto dtype = node.get_attribute<element::Type>("dtype");
Output<Node> val_node;
auto dtype = node.get_attribute<element::Type>("dtype");
auto str_value = node.get_attribute<std::string>("str_value");

if (str_value.empty()) {
auto float_value = node.get_attribute<float>("value");
switch (dtype) {
case element::i32:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {static_cast<int>(float_value)});
break;
case element::i64:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {static_cast<long long>(float_value)});
break;
case element::f32:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {float_value});
break;
case element::f64:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {static_cast<double>(float_value)});
break;
default:
throw std::runtime_error("fill_constant_batch_size_like: dtype value is invalid");
}
val_node = ngraph::opset6::Constant::create(dtype, {1}, {static_cast<StorageDataType>(float_value)});
} else {
switch (dtype) {
case element::i32:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {std::stoi(str_value)});
break;
case element::i64:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {std::stoll(str_value)});
break;
case element::f32:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {std::stof(str_value)});
break;
case element::f64:
val_node = ngraph::opset6::Constant::create(dtype, {1}, {std::stod(str_value)});
break;
default:
throw std::runtime_error("fill_constant_batch_size_like: dtype value is invalid");
}
std::stringstream ss(str_value);
StorageDataType tmp_value;
ss >> tmp_value;
val_node = ngraph::opset6::Constant::create(dtype, {1}, {static_cast<StorageDataType>(tmp_value)});
}
return val_node;
}

static Output<Node> get_seed_node(const NodeContext& node) {
Output<Node> val_node;
auto dtype = node.get_attribute<element::Type>("dtype");

switch (dtype) {
case element::i32:
val_node = get_seed_node<element::i32>(node);
break;
case element::i64:
val_node = get_seed_node<element::i64>(node);
break;
case element::f32:
val_node = get_seed_node<element::f32>(node);
break;
case element::f64:
val_node = get_seed_node<element::f64>(node);
break;
default:
throw std::runtime_error("fill_constant_batch_size_like: unsupported dtype");
}

return val_node;
Expand Down

0 comments on commit 75445d3

Please sign in to comment.