Skip to content

Commit

Permalink
[TRANSFORMATIONS] Extend PositionIDsReplacer pattern
Browse files Browse the repository at this point in the history
Extend PositionIDsReplacer pattern to support more models:

- facebook/opt-350m
- core42/jais-13b

Signed-off-by: Andrii Staikov [email protected]

Tickets:
CVS-143065
  • Loading branch information
CuriousPanCake committed Jun 6, 2024
1 parent 9e66235 commit 5812c6e
Showing 1 changed file with 5 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

#include "openvino/cc/pass/itt.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "transformations/utils/utils.hpp"

using namespace ov::op;
Expand All @@ -26,7 +28,9 @@ ov::pass::PositionIDsReplacer::PositionIDsReplacer(const Output<Node>& position_
auto convert = pattern::wrap_type<v0::Convert>({add_offset});
auto position_embed = pattern::wrap_type<v8::Gather>({pattern::any_input(), convert, pattern::any_input()});

auto add = pattern::wrap_type<v1::Add>({input_embed, position_embed});
auto mul = pattern::optional<v0::MatMul>({input_embed, pattern::any_input()});

auto add = pattern::wrap_type<v1::Add>({mul, position_embed});

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
Expand Down

0 comments on commit 5812c6e

Please sign in to comment.