From 4858e21743f99ffa6f0b63a5c322d0ce9c70ba43 Mon Sep 17 00:00:00 2001 From: Mateusz Tabaka Date: Tue, 5 Oct 2021 08:44:53 +0200 Subject: [PATCH 1/2] Add support for ONNX op "com.microsoft.EmbedLayerNormalization" Ticket: 62890 --- .../embed_layer_normalization.cpp | 58 ++++ .../embed_layer_normalization.hpp | 17 ++ .../frontend/onnx/frontend/src/ops_bridge.cpp | 2 + ngraph/test/engines_util/ie_engines.cpp | 3 + .../test/engines_util/interpreter_engine.cpp | 3 + ...ayer_normalization_dynamic_shapes.prototxt | 186 +++++++++++++ .../onnx/embed_layer_normalization.prototxt | 187 +++++++++++++ ...malization_with_segment_embedding.prototxt | 239 ++++++++++++++++ ...n_with_segment_embedding_and_mask.prototxt | 256 ++++++++++++++++++ .../onnx/onnx_import_com_microsoft.in.cpp | 222 ++++++++++++++- ngraph/test/runtime/ie/unit_test.manifest | 2 + 11 files changed, 1171 insertions(+), 4 deletions(-) create mode 100644 ngraph/frontend/onnx/frontend/src/op/com.microsoft/embed_layer_normalization.cpp create mode 100644 ngraph/frontend/onnx/frontend/src/op/com.microsoft/embed_layer_normalization.hpp create mode 100644 ngraph/test/models/onnx/dynamic_shapes/embed_layer_normalization_dynamic_shapes.prototxt create mode 100644 ngraph/test/models/onnx/embed_layer_normalization.prototxt create mode 100644 ngraph/test/models/onnx/embed_layer_normalization_with_segment_embedding.prototxt create mode 100644 ngraph/test/models/onnx/embed_layer_normalization_with_segment_embedding_and_mask.prototxt diff --git a/ngraph/frontend/onnx/frontend/src/op/com.microsoft/embed_layer_normalization.cpp b/ngraph/frontend/onnx/frontend/src/op/com.microsoft/embed_layer_normalization.cpp new file mode 100644 index 00000000000000..d46690375d3e29 --- /dev/null +++ b/ngraph/frontend/onnx/frontend/src/op/com.microsoft/embed_layer_normalization.cpp @@ -0,0 +1,58 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "op/com.microsoft/embed_layer_normalization.hpp" + 
+#include "default_opset.hpp" +#include "onnx_import/core/null_node.hpp" + +namespace ngraph { +namespace onnx_import { +namespace op { +namespace set_1 { +OutputVector embed_layer_normalization(const Node& node) { + auto nodes = node.get_ng_inputs(); + auto num_nodes = nodes.size(); + NGRAPH_CHECK(num_nodes >= 7 && num_nodes <= 8, + "EmbedLayerNormalization takes 7 or 8 inputs. Provided " + std::to_string(num_nodes)); + NGRAPH_CHECK(nodes[0].get_element_type() == element::i32, "input_ids must have int32 type"); + auto zero = default_opset::Constant::create(element::i32, Shape{1}, {0}); + auto word_embedding = std::make_shared(nodes[2], nodes[0], zero, 0); + auto input = std::make_shared(word_embedding, nodes[3]); + if (!ngraph::op::is_null(nodes[1])) { + NGRAPH_CHECK(!ngraph::op::is_null(nodes[4]), "segment_ids provided, but segment_embedding input is missing"); + NGRAPH_CHECK(nodes[1].get_element_type() == element::i32, "segment_ids must have int32 type"); + auto segment_embedding = std::make_shared(nodes[4], nodes[1], zero, 0); + input = std::make_shared(input, segment_embedding); + } + float eps = node.get_attribute_value("epsilon"); + // reduce over hidden_size + // hidden_size dimension is 2 here, because the shape after Gather(word_embedding, input_ids) + // is (batch_size, seq_len, hidden_size) + int hidden_size_dim = 2; + const auto reduction_axes = default_opset::Constant::create(element::i32, Shape{1}, {hidden_size_dim}); + std::shared_ptr result = + std::make_shared(input, reduction_axes, true, eps, ngraph::op::MVNEpsMode::INSIDE_SQRT); + // multiply by gamma + result = std::make_shared(result, nodes[5]); + // add beta + result = std::make_shared(result, nodes[6]); + std::shared_ptr mask_index; + if (num_nodes > 7 && !ngraph::op::is_null(nodes[7])) { + NGRAPH_CHECK(nodes[7].get_element_type() == element::i32, "mask must have int32 type"); + auto axis = default_opset::Constant::create(element::i32, Shape{}, {1}); + mask_index = std::make_shared(nodes[7], 
axis, false); + } else { + auto batch_size = std::make_shared(std::make_shared(nodes[0]), + zero, // indices + zero // axis + ); + mask_index = std::make_shared(zero, batch_size); + } + return {result, mask_index}; +} +} // namespace set_1 +} // namespace op +} // namespace onnx_import +} // namespace ngraph diff --git a/ngraph/frontend/onnx/frontend/src/op/com.microsoft/embed_layer_normalization.hpp b/ngraph/frontend/onnx/frontend/src/op/com.microsoft/embed_layer_normalization.hpp new file mode 100644 index 00000000000000..2d9fcecdcf7932 --- /dev/null +++ b/ngraph/frontend/onnx/frontend/src/op/com.microsoft/embed_layer_normalization.hpp @@ -0,0 +1,17 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "onnx_import/core/node.hpp" + +namespace ngraph { +namespace onnx_import { +namespace op { +namespace set_1 { +OutputVector embed_layer_normalization(const Node& node); +} // namespace set_1 +} // namespace op +} // namespace onnx_import +} // namespace ngraph diff --git a/ngraph/frontend/onnx/frontend/src/ops_bridge.cpp b/ngraph/frontend/onnx/frontend/src/ops_bridge.cpp index 20484c78953ffd..9f3ab3cf595a8c 100644 --- a/ngraph/frontend/onnx/frontend/src/ops_bridge.cpp +++ b/ngraph/frontend/onnx/frontend/src/ops_bridge.cpp @@ -30,6 +30,7 @@ #include "op/ceil.hpp" #include "op/clip.hpp" #include "op/com.microsoft/bias_gelu.hpp" +#include "op/com.microsoft/embed_layer_normalization.hpp" #include "op/com.microsoft/skip_layer_normalization.hpp" #include "op/compress.hpp" #include "op/concat.hpp" @@ -480,6 +481,7 @@ OperatorsBridge::OperatorsBridge() { REGISTER_OPERATOR_WITH_DOMAIN(OPENVINO_ONNX_DOMAIN, "Swish", 1, swish); REGISTER_OPERATOR_WITH_DOMAIN(MICROSOFT_DOMAIN, "BiasGelu", 1, bias_gelu); + REGISTER_OPERATOR_WITH_DOMAIN(MICROSOFT_DOMAIN, "EmbedLayerNormalization", 1, embed_layer_normalization); REGISTER_OPERATOR_WITH_DOMAIN(MICROSOFT_DOMAIN, "SkipLayerNormalization", 1, skip_layer_normalization); } 
diff --git a/ngraph/test/engines_util/ie_engines.cpp b/ngraph/test/engines_util/ie_engines.cpp index 3238b32c18ca14..e9b918c2e0ceb4 100644 --- a/ngraph/test/engines_util/ie_engines.cpp +++ b/ngraph/test/engines_util/ie_engines.cpp @@ -337,6 +337,9 @@ testing::AssertionResult test::IE_Engine::compare_results_with_tolerance_as_fp(c comparison_result = test::compare_with_tolerance(test_results.first, test_results.second, tolerance); break; } + case InferenceEngine::Precision::I32: + comparison_result = compare_blobs(computed_output_blob, expected_output_blob, 0); + break; default: comparison_result = testing::AssertionFailure() << "Unsupported data type encountered in " "'compare_results_with_tolerance_as_fp' method"; diff --git a/ngraph/test/engines_util/interpreter_engine.cpp b/ngraph/test/engines_util/interpreter_engine.cpp index 65f614d4c73fa8..6648f7f8fca6aa 100644 --- a/ngraph/test/engines_util/interpreter_engine.cpp +++ b/ngraph/test/engines_util/interpreter_engine.cpp @@ -124,6 +124,9 @@ testing::AssertionResult test::INTERPRETER_Engine::compare_results_with_toleranc case element::Type_t::f32: comparison_result = compare_with_fp_tolerance(expected_result_constant, result_tensor, tolerance); break; + case element::Type_t::i32: + comparison_result = compare_values(expected_result_constant, result_tensor, 0); + break; default: comparison_result = testing::AssertionFailure() << "Unsupported data type encountered in " "'compare_results_with_tolerance_as_fp' method"; diff --git a/ngraph/test/models/onnx/dynamic_shapes/embed_layer_normalization_dynamic_shapes.prototxt b/ngraph/test/models/onnx/dynamic_shapes/embed_layer_normalization_dynamic_shapes.prototxt new file mode 100644 index 00000000000000..577926c6d3114f --- /dev/null +++ b/ngraph/test/models/onnx/dynamic_shapes/embed_layer_normalization_dynamic_shapes.prototxt @@ -0,0 +1,186 @@ +ir_version: 6 +producer_name: "nGraph" +graph { + node { + input: "input_ids" + input: "segment_ids" + input: "word_embeddings" + 
input: "position_embeddings" + input: "segment_embeddings" + input: "gamma" + input: "beta" + input: "mask" + output: "output" + output: "mask_index" + name: "EmbedLayerNormalization_1" + op_type: "EmbedLayerNormalization" + attribute { + name: "epsilon" + f: 9.999999960041972e-13 + type: FLOAT + } + domain: "com.microsoft" + } + name: "graph" + input { + name: "input_ids" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_param: "batch_size" + } + dim { + dim_param: "seq_len" + } + } + } + } + } + input { + name: "segment_ids" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_param: "batch_size" + } + dim { + dim_param: "seq_len" + } + } + } + } + } + input { + name: "word_embeddings" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_param: "word_embed_len" + } + dim { + dim_value: 5 + } + } + } + } + } + input { + name: "position_embeddings" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_param: "pos_embed_len" + } + dim { + dim_value: 5 + } + } + } + } + } + input { + name: "segment_embeddings" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_param: "segment_embed_len" + } + dim { + dim_value: 5 + } + } + } + } + } + input { + name: "gamma" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 5 + } + } + } + } + } + input { + name: "beta" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 5 + } + } + } + } + } + input { + name: "mask" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_param: "batch_size" + } + dim { + dim_param: "seq_len" + } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_param: "batch_size" + } + dim { + dim_param: "seq_len" + } + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "mask_index" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_param: "batch_size" + } + } + } + } + } +} +opset_import { + version: 11 +} +opset_import { + 
domain: "com.microsoft" + version: 1 +} diff --git a/ngraph/test/models/onnx/embed_layer_normalization.prototxt b/ngraph/test/models/onnx/embed_layer_normalization.prototxt new file mode 100644 index 00000000000000..1cd1bfcc1b3d70 --- /dev/null +++ b/ngraph/test/models/onnx/embed_layer_normalization.prototxt @@ -0,0 +1,187 @@ +ir_version: 6 +producer_name: "nGraph" +graph { + node { + input: "input_ids" + input: "" + input: "word_embeddings" + input: "position_embeddings" + input: "" + input: "gamma" + input: "beta" + output: "output" + name: "EmbedLayerNormalization_1" + op_type: "EmbedLayerNormalization" + attribute { + name: "epsilon" + f: 9.999999960041972e-13 + type: FLOAT + } + domain: "com.microsoft" + } + name: "graph" + initializer { + dims: 10 + dims: 5 + data_type: 1 + name: "word_embeddings" + float_data: 0.01326417364180088 + float_data: -0.017005326226353645 + float_data: 0.021556973457336426 + float_data: -0.079218357801437378 + float_data: -0.019958715885877609 + float_data: 0.066062852740287781 + float_data: -0.063465960323810577 + float_data: -0.036202378571033478 + float_data: -0.038673330098390579 + float_data: -0.050637193024158478 + float_data: 0.0024814880453050137 + float_data: -0.017267324030399323 + float_data: -0.0047671985812485218 + float_data: -0.014202062971889973 + float_data: 0.10090816766023636 + float_data: 0.044896259903907776 + float_data: 0.015443948097527027 + float_data: -0.0010053194127976894 + float_data: 0.071923978626728058 + float_data: 0.01173736434429884 + float_data: 0.034053854644298553 + float_data: -0.037060577422380447 + float_data: 0.01355923805385828 + float_data: 0.054467327892780304 + float_data: 0.088897556066513062 + float_data: 0.019563071429729462 + float_data: 0.025579970329999924 + float_data: -0.032200627028942108 + float_data: -0.0083356937393546104 + float_data: -0.10528338700532913 + float_data: 0.04967513307929039 + float_data: -0.093638911843299866 + float_data: 0.0018587876111268997 + float_data: 
0.01037109550088644 + float_data: -0.011854520998895168 + float_data: 0.035907052457332611 + float_data: -0.061639595776796341 + float_data: -0.070428818464279175 + float_data: 0.080737568438053131 + float_data: -0.014098187908530235 + float_data: -0.066207133233547211 + float_data: 0.078362509608268738 + float_data: -0.021088391542434692 + float_data: -0.022340660914778709 + float_data: -0.065533898770809174 + float_data: -0.022695079445838928 + float_data: 0.01550679374486208 + float_data: -0.022843297570943832 + float_data: 0.044251278042793274 + float_data: -0.0071350894868373871 + } + initializer { + dims: 8 + dims: 5 + data_type: 1 + name: "position_embeddings" + float_data: 0.11355137079954147 + float_data: 0.048468157649040222 + float_data: 0.053486518561840057 + float_data: 0.01513370219618082 + float_data: 0.14626613259315491 + float_data: -0.18863441050052643 + float_data: 0.10133393853902817 + float_data: 0.098319537937641144 + float_data: 0.070722959935665131 + float_data: -0.018062451854348183 + float_data: -0.018210677430033684 + float_data: 0.018454158678650856 + float_data: 0.025413623079657555 + float_data: -0.017915787175297737 + float_data: 0.088725067675113678 + float_data: -0.10261145234107971 + float_data: -0.16650274395942688 + float_data: 0.087947741150856018 + float_data: -0.072966478765010834 + float_data: -0.072863951325416565 + float_data: -0.057195741683244705 + float_data: 0.052380021661520004 + float_data: 0.150204136967659 + float_data: 0.036691628396511078 + float_data: -0.055858571082353592 + float_data: 0.013746094889938831 + float_data: -0.041797593235969543 + float_data: 0.036348219960927963 + float_data: 0.032991457730531693 + float_data: -0.031414791941642761 + float_data: -0.026756083592772484 + float_data: -0.077081479132175446 + float_data: 0.039385091513395309 + float_data: -0.028280897065997124 + float_data: -0.039638441056013107 + float_data: 0.1023884043097496 + float_data: -0.038734495639801025 + float_data: 
0.034112773835659027 + float_data: -0.024975193664431572 + float_data: -0.061074573546648026 + } + initializer { + dims: 5 + data_type: 1 + name: "gamma" + float_data: 0.037749473005533218 + float_data: -0.10285304486751556 + float_data: -0.030169183388352394 + float_data: -0.02105225995182991 + float_data: 0.11735564470291138 + } + initializer { + dims: 5 + data_type: 1 + name: "beta" + float_data: -0.058927357196807861 + float_data: -0.019592402502894402 + float_data: 0.0062640579417347908 + float_data: -0.010709371417760849 + float_data: -0.010058049112558365 + } + input { + name: "input_ids" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 8 + } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 8 + } + dim { + dim_value: 5 + } + } + } + } + } +} +opset_import { + version: 11 +} +opset_import { + domain: "com.microsoft" + version: 1 +} diff --git a/ngraph/test/models/onnx/embed_layer_normalization_with_segment_embedding.prototxt b/ngraph/test/models/onnx/embed_layer_normalization_with_segment_embedding.prototxt new file mode 100644 index 00000000000000..36b7a1deaaa870 --- /dev/null +++ b/ngraph/test/models/onnx/embed_layer_normalization_with_segment_embedding.prototxt @@ -0,0 +1,239 @@ +ir_version: 6 +producer_name: "nGraph" +graph { + node { + input: "input_ids" + input: "segment_ids" + input: "word_embeddings" + input: "position_embeddings" + input: "segment_embeddings" + input: "gamma" + input: "beta" + output: "output" + output: "mask_index" + name: "EmbedLayerNormalization_1" + op_type: "EmbedLayerNormalization" + attribute { + name: "epsilon" + f: 9.999999960041972e-13 + type: FLOAT + } + domain: "com.microsoft" + } + name: "graph" + initializer { + dims: 10 + dims: 5 + data_type: 1 + name: "word_embeddings" + float_data: 0.01326417364180088 + float_data: -0.017005326226353645 + float_data: 0.021556973457336426 + 
float_data: -0.079218357801437378 + float_data: -0.019958715885877609 + float_data: 0.066062852740287781 + float_data: -0.063465960323810577 + float_data: -0.036202378571033478 + float_data: -0.038673330098390579 + float_data: -0.050637193024158478 + float_data: 0.0024814880453050137 + float_data: -0.017267324030399323 + float_data: -0.0047671985812485218 + float_data: -0.014202062971889973 + float_data: 0.10090816766023636 + float_data: 0.044896259903907776 + float_data: 0.015443948097527027 + float_data: -0.0010053194127976894 + float_data: 0.071923978626728058 + float_data: 0.01173736434429884 + float_data: 0.034053854644298553 + float_data: -0.037060577422380447 + float_data: 0.01355923805385828 + float_data: 0.054467327892780304 + float_data: 0.088897556066513062 + float_data: 0.019563071429729462 + float_data: 0.025579970329999924 + float_data: -0.032200627028942108 + float_data: -0.0083356937393546104 + float_data: -0.10528338700532913 + float_data: 0.04967513307929039 + float_data: -0.093638911843299866 + float_data: 0.0018587876111268997 + float_data: 0.01037109550088644 + float_data: -0.011854520998895168 + float_data: 0.035907052457332611 + float_data: -0.061639595776796341 + float_data: -0.070428818464279175 + float_data: 0.080737568438053131 + float_data: -0.014098187908530235 + float_data: -0.066207133233547211 + float_data: 0.078362509608268738 + float_data: -0.021088391542434692 + float_data: -0.022340660914778709 + float_data: -0.065533898770809174 + float_data: -0.022695079445838928 + float_data: 0.01550679374486208 + float_data: -0.022843297570943832 + float_data: 0.044251278042793274 + float_data: -0.0071350894868373871 + } + initializer { + dims: 8 + dims: 5 + data_type: 1 + name: "position_embeddings" + float_data: 0.11355137079954147 + float_data: 0.048468157649040222 + float_data: 0.053486518561840057 + float_data: 0.01513370219618082 + float_data: 0.14626613259315491 + float_data: -0.18863441050052643 + float_data: 0.10133393853902817 + 
float_data: 0.098319537937641144 + float_data: 0.070722959935665131 + float_data: -0.018062451854348183 + float_data: -0.018210677430033684 + float_data: 0.018454158678650856 + float_data: 0.025413623079657555 + float_data: -0.017915787175297737 + float_data: 0.088725067675113678 + float_data: -0.10261145234107971 + float_data: -0.16650274395942688 + float_data: 0.087947741150856018 + float_data: -0.072966478765010834 + float_data: -0.072863951325416565 + float_data: -0.057195741683244705 + float_data: 0.052380021661520004 + float_data: 0.150204136967659 + float_data: 0.036691628396511078 + float_data: -0.055858571082353592 + float_data: 0.013746094889938831 + float_data: -0.041797593235969543 + float_data: 0.036348219960927963 + float_data: 0.032991457730531693 + float_data: -0.031414791941642761 + float_data: -0.026756083592772484 + float_data: -0.077081479132175446 + float_data: 0.039385091513395309 + float_data: -0.028280897065997124 + float_data: -0.039638441056013107 + float_data: 0.1023884043097496 + float_data: -0.038734495639801025 + float_data: 0.034112773835659027 + float_data: -0.024975193664431572 + float_data: -0.061074573546648026 + } + initializer { + dims: 3 + dims: 5 + data_type: 1 + name: "segment_embeddings" + float_data: -0.027431340888142586 + float_data: -0.01666862890124321 + float_data: -0.052050836384296417 + float_data: -0.074926018714904785 + float_data: 0.0045464779250323772 + float_data: 0.054949179291725159 + float_data: 0.046781986951828003 + float_data: 0.065758734941482544 + float_data: -0.036851223558187485 + float_data: -0.041801471263170242 + float_data: 0.025191636756062508 + float_data: -0.046526473015546799 + float_data: 0.027152393013238907 + float_data: 0.026372035965323448 + float_data: -0.020972840487957001 + } + initializer { + dims: 5 + data_type: 1 + name: "gamma" + float_data: 0.037749473005533218 + float_data: -0.10285304486751556 + float_data: -0.030169183388352394 + float_data: -0.02105225995182991 + float_data: 
0.11735564470291138 + } + initializer { + dims: 5 + data_type: 1 + name: "beta" + float_data: -0.058927357196807861 + float_data: -0.019592402502894402 + float_data: 0.0062640579417347908 + float_data: -0.010709371417760849 + float_data: -0.010058049112558365 + } + input { + name: "input_ids" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 8 + } + } + } + } + } + input { + name: "segment_ids" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 8 + } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 8 + } + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "mask_index" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 3 + } + } + } + } + } + +} +opset_import { + version: 11 +} +opset_import { + domain: "com.microsoft" + version: 1 +} diff --git a/ngraph/test/models/onnx/embed_layer_normalization_with_segment_embedding_and_mask.prototxt b/ngraph/test/models/onnx/embed_layer_normalization_with_segment_embedding_and_mask.prototxt new file mode 100644 index 00000000000000..1181538c82ccaa --- /dev/null +++ b/ngraph/test/models/onnx/embed_layer_normalization_with_segment_embedding_and_mask.prototxt @@ -0,0 +1,256 @@ +ir_version: 6 +producer_name: "nGraph" +graph { + node { + input: "input_ids" + input: "segment_ids" + input: "word_embeddings" + input: "position_embeddings" + input: "segment_embeddings" + input: "gamma" + input: "beta" + input: "mask" + output: "output" + output: "mask_index" + name: "EmbedLayerNormalization_1" + op_type: "EmbedLayerNormalization" + attribute { + name: "epsilon" + f: 9.999999960041972e-13 + type: FLOAT + } + domain: "com.microsoft" + } + name: "graph" + initializer { + dims: 10 + dims: 5 + data_type: 1 + name: "word_embeddings" + float_data: 0.01326417364180088 + float_data: -0.017005326226353645 + float_data: 
0.021556973457336426 + float_data: -0.079218357801437378 + float_data: -0.019958715885877609 + float_data: 0.066062852740287781 + float_data: -0.063465960323810577 + float_data: -0.036202378571033478 + float_data: -0.038673330098390579 + float_data: -0.050637193024158478 + float_data: 0.0024814880453050137 + float_data: -0.017267324030399323 + float_data: -0.0047671985812485218 + float_data: -0.014202062971889973 + float_data: 0.10090816766023636 + float_data: 0.044896259903907776 + float_data: 0.015443948097527027 + float_data: -0.0010053194127976894 + float_data: 0.071923978626728058 + float_data: 0.01173736434429884 + float_data: 0.034053854644298553 + float_data: -0.037060577422380447 + float_data: 0.01355923805385828 + float_data: 0.054467327892780304 + float_data: 0.088897556066513062 + float_data: 0.019563071429729462 + float_data: 0.025579970329999924 + float_data: -0.032200627028942108 + float_data: -0.0083356937393546104 + float_data: -0.10528338700532913 + float_data: 0.04967513307929039 + float_data: -0.093638911843299866 + float_data: 0.0018587876111268997 + float_data: 0.01037109550088644 + float_data: -0.011854520998895168 + float_data: 0.035907052457332611 + float_data: -0.061639595776796341 + float_data: -0.070428818464279175 + float_data: 0.080737568438053131 + float_data: -0.014098187908530235 + float_data: -0.066207133233547211 + float_data: 0.078362509608268738 + float_data: -0.021088391542434692 + float_data: -0.022340660914778709 + float_data: -0.065533898770809174 + float_data: -0.022695079445838928 + float_data: 0.01550679374486208 + float_data: -0.022843297570943832 + float_data: 0.044251278042793274 + float_data: -0.0071350894868373871 + } + initializer { + dims: 8 + dims: 5 + data_type: 1 + name: "position_embeddings" + float_data: 0.11355137079954147 + float_data: 0.048468157649040222 + float_data: 0.053486518561840057 + float_data: 0.01513370219618082 + float_data: 0.14626613259315491 + float_data: -0.18863441050052643 + float_data: 
0.10133393853902817 + float_data: 0.098319537937641144 + float_data: 0.070722959935665131 + float_data: -0.018062451854348183 + float_data: -0.018210677430033684 + float_data: 0.018454158678650856 + float_data: 0.025413623079657555 + float_data: -0.017915787175297737 + float_data: 0.088725067675113678 + float_data: -0.10261145234107971 + float_data: -0.16650274395942688 + float_data: 0.087947741150856018 + float_data: -0.072966478765010834 + float_data: -0.072863951325416565 + float_data: -0.057195741683244705 + float_data: 0.052380021661520004 + float_data: 0.150204136967659 + float_data: 0.036691628396511078 + float_data: -0.055858571082353592 + float_data: 0.013746094889938831 + float_data: -0.041797593235969543 + float_data: 0.036348219960927963 + float_data: 0.032991457730531693 + float_data: -0.031414791941642761 + float_data: -0.026756083592772484 + float_data: -0.077081479132175446 + float_data: 0.039385091513395309 + float_data: -0.028280897065997124 + float_data: -0.039638441056013107 + float_data: 0.1023884043097496 + float_data: -0.038734495639801025 + float_data: 0.034112773835659027 + float_data: -0.024975193664431572 + float_data: -0.061074573546648026 + } + initializer { + dims: 3 + dims: 5 + data_type: 1 + name: "segment_embeddings" + float_data: -0.027431340888142586 + float_data: -0.01666862890124321 + float_data: -0.052050836384296417 + float_data: -0.074926018714904785 + float_data: 0.0045464779250323772 + float_data: 0.054949179291725159 + float_data: 0.046781986951828003 + float_data: 0.065758734941482544 + float_data: -0.036851223558187485 + float_data: -0.041801471263170242 + float_data: 0.025191636756062508 + float_data: -0.046526473015546799 + float_data: 0.027152393013238907 + float_data: 0.026372035965323448 + float_data: -0.020972840487957001 + } + initializer { + dims: 5 + data_type: 1 + name: "gamma" + float_data: 0.037749473005533218 + float_data: -0.10285304486751556 + float_data: -0.030169183388352394 + float_data: 
-0.02105225995182991 + float_data: 0.11735564470291138 + } + initializer { + dims: 5 + data_type: 1 + name: "beta" + float_data: -0.058927357196807861 + float_data: -0.019592402502894402 + float_data: 0.0062640579417347908 + float_data: -0.010709371417760849 + float_data: -0.010058049112558365 + } + input { + name: "input_ids" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 8 + } + } + } + } + } + input { + name: "segment_ids" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 8 + } + } + } + } + } + input { + name: "mask" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 8 + } + } + } + } + } + + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 8 + } + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "mask_index" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 3 + } + } + } + } + } +} +opset_import { + version: 11 +} +opset_import { + domain: "com.microsoft" + version: 1 +} diff --git a/ngraph/test/onnx/onnx_import_com_microsoft.in.cpp b/ngraph/test/onnx/onnx_import_com_microsoft.in.cpp index 43f74e64b06703..63611843a27dfd 100644 --- a/ngraph/test/onnx/onnx_import_com_microsoft.in.cpp +++ b/ngraph/test/onnx/onnx_import_com_microsoft.in.cpp @@ -73,7 +73,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_skip_layer_normalization_with_gamma_beta test_case.add_input(input); test_case.add_input(skip); test_case.add_expected_output(expected); - test_case.run(5); + test_case.run_with_tolerance_as_fp(); } NGRAPH_TEST(${BACKEND_NAME}, onnx_model_skip_layer_normalization_with_gamma_beta) { @@ -99,7 +99,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_skip_layer_normalization_with_gamma_beta test_case.add_input(input); test_case.add_input(skip); test_case.add_expected_output(expected); - test_case.run(7); + 
test_case.run_with_tolerance_as_fp(); } NGRAPH_TEST(${BACKEND_NAME}, onnx_model_skip_layer_normalization_with_gamma) { @@ -125,7 +125,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_skip_layer_normalization_with_gamma) { test_case.add_input(input); test_case.add_input(skip); test_case.add_expected_output(expected); - test_case.run(6); + test_case.run_with_tolerance_as_fp(); } NGRAPH_TEST(${BACKEND_NAME}, onnx_model_skip_layer_normalization_dynamic_shapes) { @@ -173,5 +173,219 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_skip_layer_normalization_dynamic_shapes) test_case.add_input(Shape{4}, beta); test_case.add_input(Shape{4}, bias); test_case.add_expected_output(Shape{3, 2, 4}, expected); - test_case.run(7); + test_case.run_with_tolerance_as_fp(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_embed_layer_normalization) { + const auto function = + onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/embed_layer_normalization.onnx")); + + std::vector input_ids = { + 8, 1, 5, 9, 8, 9, 4, 3, 0, 3, 5, 0, 2, 3, 8, 1, 3, 3, 3, 7, 0, 1, 9, 9, + }; + std::vector expected_output = { + -0.06615843, -0.18040463, 0.02199928, 0.01868065, 0.05397778, -0.11761580, -0.09138932, -0.02506775, + -0.02368510, -0.10373901, -0.05551499, -0.20972314, 0.01365213, 0.01132561, -0.08603337, -0.08906764, + 0.09692993, -0.04444099, -0.02037602, -0.03453060, -0.10214549, -0.13331436, -0.02665862, -0.01228805, + -0.14232540, -0.07032782, 0.05511986, -0.00120272, -0.04875736, -0.13051267, -0.05709254, 0.17854357, + -0.01759873, -0.01819968, 0.07573269, 0.00557164, 0.06232717, 0.00530490, -0.01565807, -0.14841977, + -0.02299280, 0.02038561, -0.00049481, 0.02575402, 0.10081697, -0.12517214, -0.09316762, -0.00974943, + -0.03093284, -0.06309240, -0.05551499, -0.20972314, 0.01365213, 0.01132561, -0.08603337, -0.06176658, + 0.08304203, -0.05025182, 0.00383657, -0.02288112, -0.11407227, -0.01386134, -0.04411830, -0.00537948, + 0.00164397, -0.03739140, 0.09941526, 0.00333974, -0.04251949, 
-0.12992151, -0.09509478, -0.11811313, + -0.03307065, -0.00866115, -0.15162414, 0.01106802, 0.06037656, 0.00035292, -0.00223284, -0.11215645, + -0.01390734, 0.07064321, 0.04028325, -0.00290875, 0.12875907, -0.12517214, -0.09316762, -0.00974943, + -0.03093284, -0.06309240, -0.08723789, 0.03130914, 0.03131931, -0.01526242, 0.20811458, -0.05696163, + 0.16304255, -0.02407495, -0.02955675, -0.03086288, -0.08130091, -0.05001551, -0.04875683, 0.00143666, + -0.12153473, -0.00018507, 0.10957482, -0.00416618, -0.01612359, -0.11605026, -0.08593204, 0.09055272, + -0.03054028, -0.03603891, -0.08479506, -0.00034568, 0.03713699, 0.00163411, -0.01738501, -0.18267182, + }; + + auto test_case = test::TestCase(function); + test_case.add_input(input_ids); + test_case.add_expected_output(expected_output); + test_case.run_with_tolerance_as_fp(1e-7f); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_embed_layer_normalization_with_segment_embedding) { + const auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/embed_layer_normalization_with_segment_embedding.onnx")); + + std::vector input_ids = { + 8, 1, 5, 9, 8, 9, 4, 3, 0, 3, 5, 0, 2, 3, 8, 1, 3, 3, 3, 7, 0, 1, 9, 9, + }; + std::vector segment_ids = { + 0, 2, 0, 2, 2, 0, 2, 0, 0, 0, 1, 1, 2, 0, 0, 1, 0, 1, 2, 2, 0, 1, 1, 1, + }; + std::vector expected_output = { + -0.06044213, -0.14845914, 0.02457689, 0.02091519, 0.09514004, -0.10280035, -0.02087995, -0.03323204, + -0.02967127, -0.13447416, -0.05191760, -0.16518904, 0.02340531, 0.02176395, 0.04972410, -0.07360736, + 0.12192874, -0.04081530, -0.02338044, -0.05671440, -0.09475864, -0.08944942, -0.03362993, -0.01683486, + -0.16770349, -0.07382569, 0.06230322, 0.02215859, -0.05212611, -0.03934773, -0.04748865, 0.18134241, + -0.01965741, -0.02202452, 0.01973994, 0.01575558, 0.04300199, 0.01436110, -0.00198062, -0.09065692, + -0.02923042, -0.00748686, 0.00717049, 0.02638642, 0.12174864, -0.12973398, -0.11872391, -0.00549398, + -0.02386289, -0.02210563, 
-0.03590920, -0.13728066, -0.01337939, 0.01538021, -0.14687485, -0.05033565, + 0.03818212, -0.04939338, 0.00961064, -0.07407621, -0.09624685, 0.05594898, -0.04948713, -0.01305631, + -0.03779668, -0.01469170, 0.12346989, 0.02082030, -0.03449103, -0.06029151, -0.09300473, -0.16308543, + -0.02370042, 0.01066893, -0.06523034, 0.00497636, 0.01933458, -0.00900802, 0.00430878, -0.13999483, + -0.02377289, 0.01760014, 0.03896973, 0.00831112, 0.15634246, -0.11109130, -0.11997811, -0.02304414, + -0.01989413, -0.12763791, -0.05698400, 0.17125534, 0.00499324, -0.02953288, 0.09178342, -0.05001877, + 0.16157132, -0.02312993, -0.02932195, -0.04914058, -0.07994118, -0.07199102, -0.04517454, 0.01249476, + -0.07525793, -0.00207180, 0.03993115, -0.01676321, -0.00214832, -0.16074482, -0.05012497, -0.00552153, + -0.04302063, -0.00549224, -0.18399858, -0.00767871, -0.02209404, -0.01383207, -0.00082931, -0.19533031, + }; + + std::vector expected_mask_index = { + 0, + 0, + 0, + }; + + auto test_case = test::TestCase(function); + test_case.add_input(input_ids); + test_case.add_input(segment_ids); + test_case.add_expected_output(expected_output); + test_case.add_expected_output(expected_mask_index); + test_case.run_with_tolerance_as_fp(1e-7); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_embed_layer_normalization_with_segment_embedding_and_mask) { + const auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/embed_layer_normalization_with_segment_embedding_and_mask.onnx")); + + std::vector input_ids = { + 8, 1, 5, 9, 8, 9, 4, 3, 0, 3, 5, 0, 2, 3, 8, 1, 3, 3, 3, 7, 0, 1, 9, 9, + }; + std::vector segment_ids = { + 0, 2, 0, 2, 2, 0, 2, 0, 0, 0, 1, 1, 2, 0, 0, 1, 0, 1, 2, 2, 0, 1, 1, 1, + }; + std::vector mask = { + 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, + }; + std::vector expected_output = { + -0.06044213, -0.14845914, 0.02457689, 0.02091519, 0.09514004, -0.10280035, -0.02087995, -0.03323204, + -0.02967127, -0.13447416, 
-0.05191760, -0.16518904, 0.02340531, 0.02176395, 0.04972410, -0.07360736, + 0.12192874, -0.04081530, -0.02338044, -0.05671440, -0.09475864, -0.08944942, -0.03362993, -0.01683486, + -0.16770349, -0.07382569, 0.06230322, 0.02215859, -0.05212611, -0.03934773, -0.04748865, 0.18134241, + -0.01965741, -0.02202452, 0.01973994, 0.01575558, 0.04300199, 0.01436110, -0.00198062, -0.09065692, + -0.02923042, -0.00748686, 0.00717049, 0.02638642, 0.12174864, -0.12973398, -0.11872391, -0.00549398, + -0.02386289, -0.02210563, -0.03590920, -0.13728066, -0.01337939, 0.01538021, -0.14687485, -0.05033565, + 0.03818212, -0.04939338, 0.00961064, -0.07407621, -0.09624685, 0.05594898, -0.04948713, -0.01305631, + -0.03779668, -0.01469170, 0.12346989, 0.02082030, -0.03449103, -0.06029151, -0.09300473, -0.16308543, + -0.02370042, 0.01066893, -0.06523034, 0.00497636, 0.01933458, -0.00900802, 0.00430878, -0.13999483, + -0.02377289, 0.01760014, 0.03896973, 0.00831112, 0.15634246, -0.11109130, -0.11997811, -0.02304414, + -0.01989413, -0.12763791, -0.05698400, 0.17125534, 0.00499324, -0.02953288, 0.09178342, -0.05001877, + 0.16157132, -0.02312993, -0.02932195, -0.04914058, -0.07994118, -0.07199102, -0.04517454, 0.01249476, + -0.07525793, -0.00207180, 0.03993115, -0.01676321, -0.00214832, -0.16074482, -0.05012497, -0.00552153, + -0.04302063, -0.00549224, -0.18399858, -0.00767871, -0.02209404, -0.01383207, -0.00082931, -0.19533031, + }; + std::vector expected_mask_index = { + 5, + 3, + 4, + }; + + auto test_case = test::TestCase(function); + test_case.add_input(input_ids); + test_case.add_input(segment_ids); + test_case.add_input(mask); + test_case.add_expected_output(expected_output); + test_case.add_expected_output(expected_mask_index); + test_case.run_with_tolerance_as_fp(1e-7); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_embed_layer_normalization_dynamic_shapes) { + const auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, 
"onnx/dynamic_shapes/embed_layer_normalization_dynamic_shapes.onnx")); + + std::vector input_ids = { + 8, 1, 5, 9, 8, 9, 4, 3, 0, 3, 5, 0, 2, 3, 8, 1, 3, 3, 3, 7, 0, 1, 9, 9, + }; + std::vector segment_ids = { + 0, 2, 0, 2, 2, 0, 2, 0, 0, 0, 1, 1, 2, 0, 0, 1, 0, 1, 2, 2, 0, 1, 1, 1, + }; + std::vector word_embeddings = { + 0.96980906, 0.65314001, 0.17090958, 0.35815218, 0.75068617, 0.60783064, 0.32504722, 0.03842543, 0.63427407, + 0.95894927, 0.65279031, 0.63505888, 0.99529958, 0.58185035, 0.41436860, 0.47469750, 0.62351012, 0.33800763, + 0.67475230, 0.31720173, 0.77834547, 0.94957107, 0.66252685, 0.01357164, 0.62284607, 0.67365962, 0.97194499, + 0.87819350, 0.50962436, 0.05571469, 0.45115921, 0.01998767, 0.44171092, 0.97958672, 0.35944447, 0.48089352, + 0.68866116, 0.88047588, 0.91823548, 0.21682213, 0.56518888, 0.86510259, 0.50896895, 0.91672295, 0.92115760, + 0.08311249, 0.27771857, 0.00935670, 0.84234208, 0.64717412, + }; + std::vector position_embeddings = { + 0.84138614, 0.26473016, 0.39782074, 0.55282146, 0.16494046, 0.36980811, 0.14644176, 0.56961840, + 0.70373726, 0.28847644, 0.43328807, 0.75610667, 0.39609829, 0.89603841, 0.63892108, 0.89155442, + 0.68005556, 0.44919774, 0.97857094, 0.11620191, 0.76702368, 0.41182014, 0.67543906, 0.24979627, + 0.31321833, 0.96541619, 0.58846509, 0.65966839, 0.53320622, 0.23053302, 0.39486930, 0.61880857, + 0.47486752, 0.47013220, 0.71607453, 0.28799102, 0.38346222, 0.74916983, 0.87845218, 0.10286336, + }; + std::vector segment_embeddings = { + 0.09237389, + 0.35404667, + 0.55181628, + 0.03362509, + 0.96896178, + 0.32099724, + 0.22126268, + 0.14126390, + 0.09725992, + 0.98404223, + 0.26034093, + 0.53702253, + 0.44792616, + 0.09956909, + 0.35231167, + }; + std::vector gamma = { + 0.46924916, + 0.84114015, + 0.90464777, + 0.03755938, + 0.50831544, + }; + std::vector beta = { + 0.16684751, + 0.77905101, + 0.86493331, + 0.41139671, + 0.13997258, + }; + std::vector mask = { + 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 
1, 1, 1, 1, 0, 0, 0, + }; + std::vector expected_output = { + -0.04089922, 0.35108989, 0.30442458, 0.39546335, 1.15422225, 0.10419128, -0.19301927, 0.01070970, + 0.43977541, 0.89119899, -0.51436460, 1.99256825, 1.41077507, 0.38642293, 0.17583044, 0.03320138, + 1.16508031, -0.24356931, 0.47440714, -0.17844005, 0.20463173, 1.90038323, 1.14138567, 0.34504607, + 0.16403235, -0.24976699, 0.29362509, 0.34502214, 0.41751838, 1.09390712, 0.12354189, 1.83025289, + 1.05569196, 0.34413773, 0.35469764, -0.69760042, 0.76338542, 1.75443077, 0.44126555, 0.18181801, + 0.73277575, 0.45443264, 0.17068321, 0.36591727, 0.72869974, -0.56090516, 0.14415455, 1.47314119, + 0.42908576, 0.73084539, -0.22373237, 2.26550221, 0.05606699, 0.39417523, 0.35234636, 0.78569502, + 0.77521765, -0.65131050, 0.40168875, 0.45527256, 0.38715565, 0.98521245, 2.21446753, 0.36345237, + -0.33269632, 0.36558092, 1.36846578, 1.37523413, 0.33698002, 0.28889543, -0.40639281, 1.01643157, + 0.59668219, 0.39197800, 1.03101778, 0.02551098, -0.03612846, -0.01371557, 0.43444607, 0.96746695, + 0.60583955, -0.10362893, 0.40574494, 0.38046724, 0.87445319, -0.00880148, -0.15437943, 0.08118075, + 0.44650543, 0.85956848, -0.27865338, 2.10837507, 0.04798460, 0.43948367, -0.10185169, 0.19978794, + 1.32323360, 1.20525467, 0.44288942, -0.84200430, 0.52563053, 0.69949460, 0.73987913, 0.34668452, + 0.74545687, 0.57696682, 0.22452033, -0.27099937, 0.39649010, 0.87083614, -0.18965788, 0.58206403, + -0.08108193, 0.42067638, 1.05117214, -0.34287399, 0.20424896, 0.27994895, 0.46011117, 0.70890665, + }; + std::vector expected_mask_index = { + 6, + 5, + 5, + }; + + auto test_case = test::TestCase(function); + test_case.add_input(Shape{3, 8}, input_ids); + test_case.add_input(Shape{3, 8}, segment_ids); + test_case.add_input(Shape{10, 5}, word_embeddings); + test_case.add_input(Shape{8, 5}, position_embeddings); + test_case.add_input(Shape{3, 5}, segment_embeddings); + test_case.add_input(Shape{5}, gamma); + test_case.add_input(Shape{5}, 
beta); + test_case.add_input(Shape{3, 8}, mask); + test_case.add_expected_output(Shape{3, 8, 5}, expected_output); + test_case.add_expected_output(Shape{3}, expected_mask_index); + test_case.run_with_tolerance_as_fp(1e-6); } diff --git a/ngraph/test/runtime/ie/unit_test.manifest b/ngraph/test/runtime/ie/unit_test.manifest index fb73dc5038232b..e935acda5e30d7 100644 --- a/ngraph/test/runtime/ie/unit_test.manifest +++ b/ngraph/test/runtime/ie/unit_test.manifest @@ -1575,3 +1575,5 @@ IE_CPU.onnx_model_gather_float_2D_neg_indices # CPU plug-in doesn't support operation with dynamic rank onnx_model_skip_layer_normalization_dynamic_shapes +# Doesn't support op with dynamic shapes +onnx_model_embed_layer_normalization_dynamic_shapes From b6f59a09cdadff0c320edb14d6acd32034761fc4 Mon Sep 17 00:00:00 2001 From: Mateusz Tabaka Date: Wed, 6 Oct 2021 16:05:55 +0200 Subject: [PATCH 2/2] style changes --- .../embed_layer_normalization.cpp | 42 +++++++++++++------ 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/ngraph/frontend/onnx/frontend/src/op/com.microsoft/embed_layer_normalization.cpp b/ngraph/frontend/onnx/frontend/src/op/com.microsoft/embed_layer_normalization.cpp index d46690375d3e29..0616cc5403204a 100644 --- a/ngraph/frontend/onnx/frontend/src/op/com.microsoft/embed_layer_normalization.cpp +++ b/ngraph/frontend/onnx/frontend/src/op/com.microsoft/embed_layer_normalization.cpp @@ -14,18 +14,33 @@ namespace set_1 { OutputVector embed_layer_normalization(const Node& node) { auto nodes = node.get_ng_inputs(); auto num_nodes = nodes.size(); + NGRAPH_CHECK(num_nodes >= 7 && num_nodes <= 8, "EmbedLayerNormalization takes 7 or 8 inputs. 
Provided " + std::to_string(num_nodes));
     NGRAPH_CHECK(nodes[0].get_element_type() == element::i32, "input_ids must have int32 type");
+
+    const auto& input_ids = nodes[0];
+    const auto& segment_ids = nodes[1];
+    const auto& word_embeddings = nodes[2];
+    const auto& position_embeddings = nodes[3];
+    const auto& segment_embeddings = nodes[4];
+    const auto& gamma = nodes[5];
+    const auto& beta = nodes[6];
+
     auto zero = default_opset::Constant::create(element::i32, Shape{1}, {0});
-    auto word_embedding = std::make_shared<default_opset::Gather>(nodes[2], nodes[0], zero, 0);
-    auto input = std::make_shared<default_opset::Add>(word_embedding, nodes[3]);
-    if (!ngraph::op::is_null(nodes[1])) {
-        NGRAPH_CHECK(!ngraph::op::is_null(nodes[4]), "segment_ids provided, but segment_embedding input is missing");
+    std::shared_ptr<ngraph::Node> input =
+        std::make_shared<default_opset::Gather>(word_embeddings, input_ids, zero, 0);
+    input = std::make_shared<default_opset::Add>(input, position_embeddings);
+
+    // add segment embeddings if available
+    if (!ngraph::op::is_null(segment_ids)) {
+        NGRAPH_CHECK(!ngraph::op::is_null(segment_embeddings),
+                     "segment_ids provided, but segment_embedding input is missing");
         NGRAPH_CHECK(nodes[1].get_element_type() == element::i32, "segment_ids must have int32 type");
-        auto segment_embedding = std::make_shared<default_opset::Gather>(nodes[4], nodes[1], zero, 0);
-        input = std::make_shared<default_opset::Add>(input, segment_embedding);
+        auto gathered_segment_embeddings =
+            std::make_shared<default_opset::Gather>(segment_embeddings, segment_ids, zero, 0);
+        input = std::make_shared<default_opset::Add>(input, gathered_segment_embeddings);
     }
+
     float eps = node.get_attribute_value<float>("epsilon");
     // reduce over hidden_size
     // hidden_size dimension is 2 here, because the shape after Gather(word_embedding, input_ids)
@@ -34,10 +49,12 @@ OutputVector embed_layer_normalization(const Node& node) {
     const auto reduction_axes = default_opset::Constant::create(element::i32, Shape{1}, {hidden_size_dim});
     std::shared_ptr<ngraph::Node> result =
         std::make_shared<default_opset::MVN>(input, reduction_axes, true, eps, ngraph::op::MVNEpsMode::INSIDE_SQRT);
-    // multiply by gamma
-    result = std::make_shared<default_opset::Multiply>(result, nodes[5]);
-    // add beta
-    result = std::make_shared<default_opset::Add>(result, nodes[6]);
+
+    // result = gamma * result + beta
+    result = std::make_shared<default_opset::Multiply>(result, gamma);
+    result = std::make_shared<default_opset::Add>(result, beta);
+
+    // compute mask_index output
     std::shared_ptr<ngraph::Node> mask_index;
     if (num_nodes > 7 && !ngraph::op::is_null(nodes[7])) {
         NGRAPH_CHECK(nodes[7].get_element_type() == element::i32, "mask must have int32 type");
@@ -45,9 +62,8 @@ OutputVector embed_layer_normalization(const Node& node) {
         mask_index = std::make_shared<default_opset::ReduceSum>(nodes[7], axis, false);
     } else {
         auto batch_size = std::make_shared<default_opset::Gather>(std::make_shared<default_opset::ShapeOf>(nodes[0]),
-                                                                  zero,  // indices
-                                                                  zero   // axis
-        );
+                                                                  zero,   // indices
+                                                                  zero);  // axis
         mask_index = std::make_shared<default_opset::Broadcast>(zero, batch_size);
     }
     return {result, mask_index};