Skip to content

Commit

Permalink
[PT FE] Support patching nn.Embedding and Conv1D for 16bit models (op…
Browse files Browse the repository at this point in the history
…envinotoolkit#25076)

### Details:
 - *item1*
 - *...*

### Tickets:
 - *CVS-143351*
  • Loading branch information
mvafin authored Jun 26, 2024
1 parent ad4eb09 commit ead5235
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 11 deletions.
33 changes: 26 additions & 7 deletions src/bindings/python/src/openvino/frontend/pytorch/patch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def module_patcher(m, name):

if extension:
# The Trampoline class is instantiated for every module replacement, so we can use class members individually for each module.

class Trampoline(torch.autograd.Function):
target_extension = extension
original_module = m
Expand Down Expand Up @@ -83,14 +84,32 @@ def unpatch_model(model, orig_forward_name):


def __make_16bit_traceable(model: torch.nn.Module):
# Replace torch.nn.Linear with ModuleExtension and move other modules to fp32
extensions = {torch.nn.Linear: ModuleExtension(
torch.nn.Linear,
"aten::linear",
evaluate=lambda module, *args, **kwargs: torch.ones(
list(args[0].shape[:-1]) + [module.out_features], dtype=torch.float32) * 0.5,
convert=lambda module, target_op, *args, **kwargs: target_op(args[0], module.weight, module.bias))
"""
Prepare a 16-bit PyTorch model for tracing with OpenVINO.
- Replace known list of modules with ModuleExtension.
- Convert other modules with weights to FP32.
"""
extensions = {
torch.nn.Linear: ModuleExtension(
torch.nn.Linear, "ov_ext::linear",
evaluate=lambda module, *args, **kwargs: torch.full(
list(args[0].shape[:-1]) + [module.out_features], 0.5, dtype=torch.float32),
convert=lambda module, target_op, *args, **kwargs: target_op(args[0], module.weight, module.bias)),
torch.nn.Embedding: ModuleExtension(
torch.nn.Embedding, "ov_ext::embedding",
evaluate=lambda module, *args, **kwargs: torch.full(
list(args[0].shape) + [module.embedding_dim], 0.5, dtype=torch.float32),
convert=lambda module, target_op, *args, **kwargs: target_op(module.weight, args[0], module.padding_idx, module.scale_grad_by_freq, module.sparse)),
}
try:
from transformers.pytorch_utils import Conv1D
extensions[Conv1D] = ModuleExtension(
Conv1D, "ov_ext::conv1d",
evaluate=lambda module, *args, **kwargs: torch.full(
list(args[0].shape[:-1]) + [module.nf], 0.5, dtype=torch.float32),
convert=lambda module, target_op, *args, **kwargs: target_op(args[0], module.weight, module.bias))
except:
pass
patch_model(model, extensions,
"_openvino_module_extension_patch_orig_forward")
for _, module in model.named_modules():
Expand Down
28 changes: 28 additions & 0 deletions src/frontends/pytorch/src/op/addmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scatter_elements_update.hpp"
#include "openvino/op/shape_of.hpp"
#include "utils.hpp"

namespace ov {
Expand Down Expand Up @@ -58,6 +63,29 @@ OutputVector translate_addmm_fx(const NodeContext& context) {
return {translate_addmm_common(context, beta, alpha)};
};

OutputVector translate_conv1d_ext(const NodeContext& context) {
// not really a convolution, implemented based on
// https://github.com/huggingface/transformers/blob/0ed3ffcb4461a244b87781a24e5ebd0a78f98142/src/transformers/pytorch_utils.py#L84
num_inputs_check(context, 3, 3);
auto x = context.get_input(0);
auto weight = context.get_input(1);
weight = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(weight, x));
auto bias = context.get_input(2);
bias = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(bias, x));

auto neg_one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
auto shape_x = context.mark_node(std::make_shared<v3::ShapeOf>(x, element::i32));
auto x_last_dim = context.mark_node(std::make_shared<v8::Gather>(shape_x, neg_one, zero));
auto x_new_shape = context.mark_node(std::make_shared<v0::Concat>(OutputVector{neg_one, x_last_dim}, 0));

auto x_new = context.mark_node(std::make_shared<v1::Reshape>(x, x_new_shape, false));
auto mm = context.mark_node(std::make_shared<v0::MatMul>(x_new, weight));
auto addmm = context.mark_node(std::make_shared<v1::Add>(bias, mm));
auto size_out = context.mark_node(std::make_shared<v12::ScatterElementsUpdate>(shape_x, neg_one, neg_one, zero));
return {context.mark_node(std::make_shared<v1::Reshape>(addmm, size_out, false))};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
Expand Down
13 changes: 12 additions & 1 deletion src/frontends/pytorch/src/op/embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ OutputVector translate_embedding(const NodeContext& context) {
auto data = context.get_input(0);
auto indices = context.get_input(1);
indices = context.mark_node(std::make_shared<ov::op::v0::Convert>(indices, element::i32));
// skip parameters 2, 3, 4 used only during trainig:
// skip parameters 2, 3, 4 used only during training:
// padding_idx - if specified, the entries at padding_idx do not contribute to the gradient
// scale_grad_by_freq - if given, this will scale gradients by the inverse of frequency of
// the words in the mini-batch.
Expand All @@ -29,6 +29,17 @@ OutputVector translate_embedding(const NodeContext& context) {
return {context.mark_node(std::make_shared<ov::op::v8::Gather>(data, indices, axis_0))};
};

OutputVector translate_embedding_ext(const NodeContext& context) {
// used in 16-bit patching
num_inputs_check(context, 2, 5);
auto data = context.get_input(0);
data = context.mark_node(std::make_shared<ov::op::v0::Convert>(data, element::f32));
auto indices = context.get_input(1);
indices = context.mark_node(std::make_shared<ov::op::v0::Convert>(indices, element::i32));
auto axis_0 = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{}, {0}));
return {context.mark_node(std::make_shared<ov::op::v8::Gather>(data, indices, axis_0))};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
Expand Down
5 changes: 5 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ OP_CONVERTER(translate_channel_shuffle);
OP_CONVERTER(translate_clamp);
OP_CONVERTER(translate_constant);
OP_CONVERTER(translate_conv_transposend);
OP_CONVERTER(translate_conv1d_ext);
OP_CONVERTER(translate_convnd);
OP_CONVERTER(translate_convolution);
OP_CONVERTER(translate_convolution_mode);
Expand All @@ -72,6 +73,7 @@ OP_CONVERTER(translate_dot);
OP_CONVERTER(translate_elu);
OP_CONVERTER(translate_embedding);
OP_CONVERTER(translate_embedding_bag);
OP_CONVERTER(translate_embedding_ext);
OP_CONVERTER(translate_empty);
OP_CONVERTER(translate_empty_like);
OP_CONVERTER(translate_erf);
Expand Down Expand Up @@ -702,6 +704,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::zero_", op::inplace_op<op::translate_zeros_like>},
{"aten::zeros", op::translate_zeros},
{"aten::zeros_like", op::translate_zeros_like},
{"ov_ext::embedding", op::translate_embedding_ext},
{"ov_ext::conv1d", op::translate_conv1d_ext},
{"ov_ext::linear", op::translate_linear},
{"prim::Constant", op::translate_constant},
{"prim::device", op::translate_constant},
// prim::DictConstruct - Supported in limited set of patterns
Expand Down
10 changes: 7 additions & 3 deletions tests/layer_tests/py_frontend_tests/test_torch_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,16 +687,20 @@ def test_patched_16bit_model_converts():
from openvino.frontend.pytorch import patch_model
from openvino import convert_model, compile_model
import copy
from transformers.pytorch_utils import Conv1D

class ModelWithLinear(torch.nn.Module):
def __init__(self):
super().__init__()

self.branch1 = torch.nn.Sequential(
torch.nn.Linear(64, 32), torch.nn.ReLU()
torch.nn.Embedding(10, 64),
torch.nn.Linear(64, 32),
torch.nn.ReLU()
)
self.branch2 = torch.nn.Sequential(
torch.nn.Linear(128, 64), torch.nn.ReLU()
Conv1D(256, 128),
torch.nn.Linear(256, 64), torch.nn.ReLU()
)
self.buffer = torch.ones(32)

Expand All @@ -705,7 +709,7 @@ def forward(self, x1, x2):
out2 = self.branch2(x2)
return (out1 + self.buffer, out2)

example = (torch.randn(32, 64), torch.randn(32, 128))
example = (torch.randint(0, 10, [32, 64]), torch.randn(32, 128))
model_ref = ModelWithLinear()
with torch.no_grad():
res_ref = model_ref(*example)
Expand Down

0 comments on commit ead5235

Please sign in to comment.