Skip to content

Commit

Permalink
[ONNX] remove hardcoded shape in GroupNorm operator (#3682)
Browse files Browse the repository at this point in the history
  • Loading branch information
mateusztabaka authored Dec 22, 2020
1 parent b17e0d4 commit e9b89b0
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,6 @@ namespace ngraph
size_t rank_size = pshape.rank().get_length();
NGRAPH_CHECK(rank_size >= 3, "3-D and above tensors supported only");

if (pshape.is_static())
{
const auto& shape = pshape.to_shape();
std::vector<size_t> new_shape{
shape[0], num_groups, shape[1] / num_groups};
for (size_t i = 2; i < rank_size; i++)
{
new_shape.push_back(shape[i]);
}
return default_opset::Constant::create(
element::i64, Shape{new_shape.size()}, new_shape);
}

auto shape = std::make_shared<default_opset::ShapeOf>(data);
auto splits = builder::opset1::split(shape, rank_size);
auto num_groups_const =
Expand Down Expand Up @@ -92,18 +79,7 @@ namespace ngraph
static_cast<size_t>(node.get_attribute_value<int64_t>("num_groups"));
float eps = node.get_attribute_value<float>("eps", 1e-5);

auto data_pshape = data.get_partial_shape();
std::shared_ptr<ngraph::Node> data_shape_node;
if (data_pshape.is_static())
{
auto shape = data_pshape.to_shape();
data_shape_node = default_opset::Constant::create(
element::u64, Shape{shape.size()}, shape);
}
else
{
data_shape_node = std::make_shared<default_opset::ShapeOf>(data);
}
auto data_shape_node = std::make_shared<default_opset::ShapeOf>(data);
auto data_reshaped = std::make_shared<default_opset::Reshape>(
data, detail::create_group_norm_shape(data, num_groups), true);

Expand Down
2 changes: 1 addition & 1 deletion ngraph/test/onnx/onnx_import.in.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3213,7 +3213,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_group_norm)
{
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/group_norm.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
auto test_case = test::TestCase<TestEngine, test::TestCaseType::DYNAMIC>(function);
Shape shape{2, 8, 2, 2};
int size = shape_size(shape);
std::vector<float> data(size);
Expand Down

0 comments on commit e9b89b0

Please sign in to comment.