From a1a69dbea94c2f411c8c33efdc85b9921aeb1432 Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Wed, 26 Jun 2024 17:16:19 +0200
Subject: [PATCH] [PT FE] Support patching nn.Embedding and Conv1D for 16bit
 models (#25076)

### Details:
- Extend `__make_16bit_traceable` in `patch_model.py` to patch `torch.nn.Embedding` and `transformers.pytorch_utils.Conv1D` with `ModuleExtension`, in addition to `torch.nn.Linear`.
- Introduce the matching `ov_ext::linear`, `ov_ext::embedding` and `ov_ext::conv1d` operations in the PyTorch frontend and register their translators in the op table.

### Tickets:
- *CVS-143351*
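For reviewers, a minimal usage sketch of the patched flow (`SomeModel` and `example` are hypothetical placeholders; `__make_16bit_traceable` is the module-private helper that the updated `test_patched_16bit_model_converts` below invokes the same way):

```python
import torch
from openvino import convert_model
from openvino.frontend.pytorch import patch_model

model = SomeModel().half()  # hypothetical fp16 model using Linear/Embedding/Conv1D
# Replace supported 16-bit modules with ov_ext trampolines so tracing
# never executes real 16-bit kernels on CPU:
patch_model.__make_16bit_traceable(model)
with torch.no_grad():
    ov_model = convert_model(model, example_input=example)  # example: matching inputs
```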
+ """ + extensions = { + torch.nn.Linear: ModuleExtension( + torch.nn.Linear, "ov_ext::linear", + evaluate=lambda module, *args, **kwargs: torch.full( + list(args[0].shape[:-1]) + [module.out_features], 0.5, dtype=torch.float32), + convert=lambda module, target_op, *args, **kwargs: target_op(args[0], module.weight, module.bias)), + torch.nn.Embedding: ModuleExtension( + torch.nn.Embedding, "ov_ext::embedding", + evaluate=lambda module, *args, **kwargs: torch.full( + list(args[0].shape) + [module.embedding_dim], 0.5, dtype=torch.float32), + convert=lambda module, target_op, *args, **kwargs: target_op(module.weight, args[0], module.padding_idx, module.scale_grad_by_freq, module.sparse)), } + try: + from transformers.pytorch_utils import Conv1D + extensions[Conv1D] = ModuleExtension( + Conv1D, "ov_ext::conv1d", + evaluate=lambda module, *args, **kwargs: torch.full( + list(args[0].shape[:-1]) + [module.nf], 0.5, dtype=torch.float32), + convert=lambda module, target_op, *args, **kwargs: target_op(args[0], module.weight, module.bias)) + except: + pass patch_model(model, extensions, "_openvino_module_extension_patch_orig_forward") for _, module in model.named_modules(): diff --git a/src/frontends/pytorch/src/op/addmm.cpp b/src/frontends/pytorch/src/op/addmm.cpp index 2f43079a037405..522c5237ecd11b 100644 --- a/src/frontends/pytorch/src/op/addmm.cpp +++ b/src/frontends/pytorch/src/op/addmm.cpp @@ -4,9 +4,14 @@ #include "openvino/frontend/pytorch/node_context.hpp" #include "openvino/op/add.hpp" +#include "openvino/op/concat.hpp" #include "openvino/op/convert_like.hpp" +#include "openvino/op/gather.hpp" #include "openvino/op/matmul.hpp" #include "openvino/op/multiply.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/scatter_elements_update.hpp" +#include "openvino/op/shape_of.hpp" #include "utils.hpp" namespace ov { @@ -58,6 +63,29 @@ OutputVector translate_addmm_fx(const NodeContext& context) { return {translate_addmm_common(context, beta, alpha)}; }; +OutputVector translate_conv1d_ext(const NodeContext& context) { + // not really a convolution, implemented based on + // https://github.com/huggingface/transformers/blob/0ed3ffcb4461a244b87781a24e5ebd0a78f98142/src/transformers/pytorch_utils.py#L84 + num_inputs_check(context, 3, 3); + auto x = context.get_input(0); + auto weight = context.get_input(1); + weight = context.mark_node(std::make_shared(weight, x)); + auto bias = context.get_input(2); + bias = context.mark_node(std::make_shared(bias, x)); + + auto neg_one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1})); + auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0})); + auto shape_x = context.mark_node(std::make_shared(x, element::i32)); + auto x_last_dim = context.mark_node(std::make_shared(shape_x, neg_one, zero)); + auto x_new_shape = context.mark_node(std::make_shared(OutputVector{neg_one, x_last_dim}, 0)); + + auto x_new = context.mark_node(std::make_shared(x, x_new_shape, false)); + auto mm = context.mark_node(std::make_shared(x_new, weight)); + auto addmm = context.mark_node(std::make_shared(bias, mm)); + auto size_out = context.mark_node(std::make_shared(shape_x, neg_one, neg_one, zero)); + return {context.mark_node(std::make_shared(addmm, size_out, false))}; +}; + } // namespace op } // namespace pytorch } // namespace frontend diff --git a/src/frontends/pytorch/src/op/embedding.cpp b/src/frontends/pytorch/src/op/embedding.cpp index 0bde94b432649e..8f80699ec007a9 100644 --- a/src/frontends/pytorch/src/op/embedding.cpp +++ 
diff --git a/src/frontends/pytorch/src/op/embedding.cpp b/src/frontends/pytorch/src/op/embedding.cpp
index 0bde94b432649e..8f80699ec007a9 100644
--- a/src/frontends/pytorch/src/op/embedding.cpp
+++ b/src/frontends/pytorch/src/op/embedding.cpp
@@ -20,7 +20,7 @@ OutputVector translate_embedding(const NodeContext& context) {
     auto data = context.get_input(0);
     auto indices = context.get_input(1);
     indices = context.mark_node(std::make_shared<ov::op::v0::Convert>(indices, element::i32));
-    // skip parameters 2, 3, 4 used only during trainig:
+    // skip parameters 2, 3, 4 used only during training:
     // padding_idx - if specified, the entries at padding_idx do not contribute to the gradient
     // scale_grad_by_freq - if given, this will scale gradients by the inverse of frequency of
     // the words in the mini-batch.
@@ -29,6 +29,17 @@ OutputVector translate_embedding(const NodeContext& context) {
     return {context.mark_node(std::make_shared<ov::op::v8::Gather>(data, indices, axis_0))};
 };
 
+OutputVector translate_embedding_ext(const NodeContext& context) {
+    // used in 16-bit patching
+    num_inputs_check(context, 2, 5);
+    auto data = context.get_input(0);
+    data = context.mark_node(std::make_shared<ov::op::v0::Convert>(data, element::f32));
+    auto indices = context.get_input(1);
+    indices = context.mark_node(std::make_shared<ov::op::v0::Convert>(indices, element::i32));
+    auto axis_0 = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{}, {0}));
+    return {context.mark_node(std::make_shared<ov::op::v8::Gather>(data, indices, axis_0))};
+};
+
 } // namespace op
 } // namespace pytorch
 } // namespace frontend
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 9cc73f854bbf6a..5e00395614f31d 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -57,6 +57,7 @@ OP_CONVERTER(translate_channel_shuffle);
 OP_CONVERTER(translate_clamp);
 OP_CONVERTER(translate_constant);
 OP_CONVERTER(translate_conv_transposend);
+OP_CONVERTER(translate_conv1d_ext);
 OP_CONVERTER(translate_convnd);
 OP_CONVERTER(translate_convolution);
 OP_CONVERTER(translate_convolution_mode);
@@ -72,6 +73,7 @@ OP_CONVERTER(translate_dot);
 OP_CONVERTER(translate_elu);
 OP_CONVERTER(translate_embedding);
 OP_CONVERTER(translate_embedding_bag);
+OP_CONVERTER(translate_embedding_ext);
 OP_CONVERTER(translate_empty);
 OP_CONVERTER(translate_empty_like);
 OP_CONVERTER(translate_erf);
@@ -702,6 +704,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"aten::zero_", op::inplace_op<op::translate_zeros_like>},
         {"aten::zeros", op::translate_zeros},
         {"aten::zeros_like", op::translate_zeros_like},
+        {"ov_ext::embedding", op::translate_embedding_ext},
+        {"ov_ext::conv1d", op::translate_conv1d_ext},
+        {"ov_ext::linear", op::translate_linear},
         {"prim::Constant", op::translate_constant},
         {"prim::device", op::translate_constant},
         // prim::DictConstruct - Supported in limited set of patterns
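`ov_ext::embedding` lowers to a `Convert` of the weight to f32 followed by a `Gather` along axis 0. A quick PyTorch check of that equivalence (illustrative values only, not part of the patch):

```python
import torch

weight = torch.randn(10, 4, dtype=torch.float16)  # 16-bit embedding table
indices = torch.tensor([[1, 3], [0, 2]], dtype=torch.int32)
# Convert(weight, f32) followed by Gather(axis=0) is plain row indexing:
expected = torch.nn.functional.embedding(indices.long(), weight.float())
assert torch.equal(weight.float()[indices.long()], expected)
```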
diff --git a/tests/layer_tests/py_frontend_tests/test_torch_frontend.py b/tests/layer_tests/py_frontend_tests/test_torch_frontend.py
index 18b3a4776f686a..155b772d560222 100644
--- a/tests/layer_tests/py_frontend_tests/test_torch_frontend.py
+++ b/tests/layer_tests/py_frontend_tests/test_torch_frontend.py
@@ -687,16 +687,20 @@ def test_patched_16bit_model_converts():
     from openvino.frontend.pytorch import patch_model
     from openvino import convert_model, compile_model
     import copy
+    from transformers.pytorch_utils import Conv1D
 
     class ModelWithLinear(torch.nn.Module):
         def __init__(self):
             super().__init__()
             self.branch1 = torch.nn.Sequential(
-                torch.nn.Linear(64, 32), torch.nn.ReLU()
+                torch.nn.Embedding(10, 64),
+                torch.nn.Linear(64, 32),
+                torch.nn.ReLU()
             )
             self.branch2 = torch.nn.Sequential(
-                torch.nn.Linear(128, 64), torch.nn.ReLU()
+                Conv1D(256, 128),
+                torch.nn.Linear(256, 64), torch.nn.ReLU()
             )
             self.buffer = torch.ones(32)
@@ -705,7 +709,7 @@ def forward(self, x1, x2):
         out2 = self.branch2(x2)
         return (out1 + self.buffer, out2)
 
-    example = (torch.randn(32, 64), torch.randn(32, 128))
+    example = (torch.randint(0, 10, [32, 64]), torch.randn(32, 128))
     model_ref = ModelWithLinear()
     with torch.no_grad():
         res_ref = model_ref(*example)