Skip to content

Commit

Permalink
update rope patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
itikhono committed Jul 10, 2024
1 parent 8b95ed6 commit dd69cb4
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1102,21 +1102,11 @@ inline std::shared_ptr<Node> operator|(const std::shared_ptr<Node>& lhs, const s
OutputVector{lhs->get_default_output(), rhs->get_default_output()});
}

inline std::shared_ptr<Node> GenSlice2(detail::PatternNode data,
detail::PatternNode start,
detail::PatternNode stop,
detail::PatternNode step,
size_t axis,
bool single_axis = false) {
std::shared_ptr<Node> opt1;
if (single_axis) {
opt1 = makePattern<opset8::Slice>({data, start, stop, step, Symbol(axis)});
} else {
std::vector<Symbol> axes(axis + 1);
std::iota(axes.begin(), axes.end(), 0);
opt1 = makePattern<opset8::Slice>({data, start, stop, step, axes});
}

inline std::shared_ptr<Node> GenStridedSlice(detail::PatternNode data,
detail::PatternNode start,
detail::PatternNode stop,
detail::PatternNode step,
size_t axis) {
std::vector<int64_t> begin_mask(axis + 1, 1);
std::vector<int64_t> end_mask(axis + 1, 1);
std::vector<int64_t> new_axis_mask;
Expand All @@ -1132,7 +1122,7 @@ inline std::shared_ptr<Node> GenSlice2(detail::PatternNode data,
{"new_axis_mask", new_axis_mask},
{"shrink_axis_mask", shrink_axis_mask},
{"ellipsis_mask", ellipsis_mask}});
return opt1 | opt2;
return opt2;
}

inline std::shared_ptr<Node> GenSlice(detail::PatternNode data, Symbol start, Symbol stop, Symbol step, size_t axis) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,9 @@ ov::pass::RoPEFusionCosSinPreprocess::RoPEFusionCosSinPreprocess() {
auto gather_positions = makePattern("i32[?,?,?,?]");

auto prepare_cos_sin_gptneox = [&](std::shared_ptr<Node> const_tab) {
auto slice = GenSlice2(const_tab, {0}, node_batch_size, {1}, 0);
return makePattern<opset6::GatherElements>({slice, gather_positions}, {{"axis", 2}});
auto slice = GenStridedSlice(const_tab, {0}, node_batch_size, {1}, 0);
auto strided_slice = GenStridedSlice(const_tab, {0}, node_batch_size, {1}, 0);
return makePattern<opset6::GatherElements>({strided_slice | slice, gather_positions}, {{"axis", 2}});
};

auto seq_len = makePattern("i32[1]");
Expand All @@ -130,13 +131,16 @@ ov::pass::RoPEFusionCosSinPreprocess::RoPEFusionCosSinPreprocess() {
auto head_dims = ov::gen_pattern::Symbol("head_dims");
auto prepare_cos_sin_llama = [&](std::shared_ptr<Node> const_tab) {
auto ScatterUpdate = makePattern<opset3::ScatterUpdate>({{0, 0, 0}, 2, seq_len, 0});
auto slice_Slice = GenSlice2(const_tab, {0, 0, 0}, ScatterUpdate, {1, 1, 1}, 2);
auto squeeze = makePattern<opset1::Reshape>({slice_Slice, {-1, head_dims}});
auto slice_Slice = makePattern<ov::opset8::Slice>({const_tab, {0}, seq_len, {1}, {2}});
auto slice_StridedSlice = GenStridedSlice(const_tab, {0, 0, 0}, ScatterUpdate, {1, 1, 1}, 2);
auto squeeze = makePattern<opset1::Reshape>({slice_StridedSlice | slice_Slice, {-1, head_dims}});
auto index_Gather = makePattern<opset8::Gather>({squeeze, gather_positions_2d, 0}, {{"batch_dims", 0}});

// another simplified pattern for gathering at position_ids
auto slice_Slice2 = GenSlice2(const_tab, {0}, seq_len, {1}, 0);
auto index_Gather2 = makePattern<opset8::Gather>({slice_Slice2, gather_positions_2d, 0}, {{"batch_dims", 0}});
auto slice_Slice2 = makePattern<ov::opset8::Slice>({const_tab, {0}, seq_len, {1}, {0}});
auto slice_StridedSlice2 = GenStridedSlice(const_tab, {0}, seq_len, {1}, 0);
auto index_Gather2 = makePattern<opset8::Gather>({slice_Slice2 | slice_StridedSlice2, gather_positions_2d, 0},
{{"batch_dims", 0}});

auto unsqueeze = makePattern<opset1::Reshape>({index_Gather | index_Gather2, {1, 1, -1, head_dims}});
auto unsqueeze2 = makePattern<opset1::Unsqueeze>({index_Gather2, 1});
Expand Down Expand Up @@ -454,14 +458,15 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id) {

auto x_even = makePattern<opset8::Gather>({reshape_Reshape_453, 0, -1}, {{"batch_dims", 0}});
auto x_odd = makePattern<opset8::Gather>({reshape_Reshape_453, 1, -1}, {{"batch_dims", 0}});

auto slice_Slice_449 = GenSlice2(cos_sin_cache, {0}, seq_length, {1}, 0);
auto slice_Slice_449 = makePattern<ov::opset8::Slice>({cos_sin_cache, {0}, seq_length, {1}, {0}});
auto slice_StridedSlice_449 = GenStridedSlice(cos_sin_cache, {0}, seq_length, {1}, 0);
auto var_split_2 = makePattern<opset1::VariadicSplit>({cos_sin_cache, 0, {0, ov::gen_pattern::Symbol("end")}});
var_split_2->set_output_size(2);

auto view_Reshape_460 = makePattern<opset1::Reshape>(
{slice_Slice_449 | var_split_2->output(0), ListConstruct_379_Concat | const_target_shape_2},
{{"special_zero", false}});
auto view_Reshape_460 =
makePattern<opset1::Reshape>({slice_StridedSlice_449 | slice_Slice_449 | var_split_2->output(0),
ListConstruct_379_Concat | const_target_shape_2},
{{"special_zero", false}});

auto cos_tab = makePattern<opset8::Gather>({view_Reshape_460, 0, -1}, {{"batch_dims", 0}});
auto x_even_cos = makePattern<opset1::Multiply>({x_even, cos_tab}, {{"auto_broadcast", "numpy"}});
Expand Down Expand Up @@ -568,14 +573,15 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
auto neg_Multiply = makePattern<opset1::Multiply>({Gather_311651, {-1}}, {{"auto_broadcast", "numpy"}});

auto ScatterUpdate_463814 = makePattern<opset3::ScatterUpdate>({{0, 0}, {1}, Gather_377635 | neg_Multiply, {0}});
auto slice_Slice_446 = GenSlice2(rotary_emb_cos,
ScatterUpdate_463814 | Gather_377635 | neg_Multiply,
{INT_MAX},
{1},
1,
true); // tensor_array<f32[1,..4096,1,128]>
auto slice_Slice_446 =
makePattern<ov::opset8::Slice>({rotary_emb_cos, Gather_377635 | neg_Multiply, {INT_MAX}, {1}, {1}});
auto slice_StridedSlice_446 = GenStridedSlice(rotary_emb_cos,
ScatterUpdate_463814,
{0, INT_MAX},
{1, 1},
1); // tensor_array<f32[1,..4096,1,128]>
auto mul_Multiply_552 =
makePattern<opset1::Multiply>({slice_Slice_543, slice_Slice_446},
makePattern<opset1::Multiply>({slice_Slice_543, slice_StridedSlice_446 | slice_Slice_446},
{{"auto_broadcast", "numpy"}}); // tensor_array<f32[?,?,32,128]>

auto reshape_opt1 = [&](std::shared_ptr<Node> input_BLHS) {
Expand Down Expand Up @@ -612,14 +618,15 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
makePattern<opset1::Squeeze>({ListUnpack_586_Split->output(0), -2}); // tensor_array<f32[?,?,32,64]>
auto cat_Concat_593 = makePattern<opset1::Concat>({ListUnpack_586_Squeeze_0, ListUnpack_586_Squeeze},
{{"axis", -1}}); // tensor_array<f32[?,?,32,128]>
auto slice_Slice_470 = GenSlice2(rotary_emb_sin,
ScatterUpdate_463814 | Gather_377635 | neg_Multiply,
{INT_MAX},
{1},
1,
true); // tensor_array<f32[1,..4096,1,128]>
auto slice_StridedSlice_470 = GenStridedSlice(rotary_emb_sin,
ScatterUpdate_463814,
{0, INT_MAX},
{1, 1},
1); // tensor_array<f32[1,..4096,1,128]>
auto slice_Slice_470 =
makePattern<opset8::Slice>({rotary_emb_sin, Gather_377635 | neg_Multiply, {INT_MAX}, {1}, {1}});
auto mul_Multiply_594 =
makePattern<opset1::Multiply>({cat_Concat_593, slice_Slice_470},
makePattern<opset1::Multiply>({cat_Concat_593, slice_StridedSlice_470 | slice_Slice_470},
{{"auto_broadcast", "numpy"}}); // tensor_array<f32[?,?,32,128]>
auto add_Add_597 = makePattern<opset1::Add>({mul_Multiply_552, mul_Multiply_594},
{{"auto_broadcast", "numpy"}}); // tensor_array<f32[?,?,32,128]>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,15 @@ CausalMaskPreprocess::CausalMaskPreprocess() {
makePattern<ov::opset8::Gather>({ShapeOf_49034, {1}, 0}, {{"batch_dims", 0}}); // tensor_array<i32[1]>
auto ScatterUpdate_93502 =
makePattern<ov::opset3::ScatterUpdate>({{0, 0, 0, 0}, {3}, Gather_41642, {0}}); // tensor_array<i32[4]>
auto SliceAssign_201_Slice = GenSlice2(SliceAssign_201_Reshape, {0, 0, 0, 0}, ScatterUpdate_93502, {1, 1, 1, 1}, 3); // tensor_array<i32[?,1,8192,..8192]>
auto SliceAssign_201_Slice = makePattern<ov::opset8::Slice>({SliceAssign_201_Reshape, {0}, Gather_41642, {1}, {3}});
auto SliceAssign_201_StridedSlice = GenStridedSlice(SliceAssign_201_Reshape, {0, 0, 0, 0},
ScatterUpdate_93502, {1, 1, 1, 1}, 3); // tensor_array<i32[?,1,8192,..8192]>
auto SliceAssign_201_Reshape_1 =
makePattern<ov::opset1::Reshape>({SliceAssign_201_Slice, {-1, 1}},
makePattern<ov::opset1::Reshape>({SliceAssign_201_Slice | SliceAssign_201_StridedSlice, {-1, 1}},
{{"special_zero", false}}); // tensor_array<i32[?,1]>
auto causal_mask_boolean_1 = GenSlice2(mul_Multiply_1, {0, 0, 0, 0}, ScatterUpdate_93502, {1, 1, 1, 1}, 3); // tensor_array<f32[?,1,8192,..8192]>
auto causal_mask_boolean_slice = makePattern<ov::opset8::Slice>({mul_Multiply_1, {0}, Gather_41642, {1}, {3}});
auto causal_mask_boolean_strided_slice = GenStridedSlice(mul_Multiply_1, {0, 0, 0, 0},
ScatterUpdate_93502, {1, 1, 1, 1}, 3); // tensor_array<f32[?,1,8192,..8192]>
auto Constant_107278 = makeConst(ov::element::f32,
ov::Shape({
1,
Expand All @@ -139,7 +143,7 @@ CausalMaskPreprocess::CausalMaskPreprocess() {
}),
{0.000000f});
auto eq_Equal =
makePattern<ov::opset1::Equal>({causal_mask_boolean_1, Constant_107278},
makePattern<ov::opset1::Equal>({causal_mask_boolean_slice | causal_mask_boolean_strided_slice, Constant_107278},
{{"auto_broadcast", "numpy"}}); // tensor_array<u8[?,1,8192,..8192]>
auto unsqueeze_Unsqueeze_1 =
makePattern<ov::opset1::Unsqueeze>({attention_mask, {1, 2}}); // tensor_array<i32[?,1,1,?]>
Expand All @@ -159,9 +163,11 @@ CausalMaskPreprocess::CausalMaskPreprocess() {
makePattern<ov::opset1::LogicalAnd>({eq_Equal, eq_Equal_1},
{{"auto_broadcast", "numpy"}}); // tensor_array<u8[?,1,8192,?]>
auto masked_fill_Select =
makePattern<ov::opset1::Select>({mul_LogicalAnd, -FLT_MAX, causal_mask_boolean_1},
makePattern<ov::opset1::Select>({mul_LogicalAnd, -FLT_MAX,
causal_mask_boolean_slice | causal_mask_boolean_strided_slice},
{{"auto_broadcast", "numpy"}}); // tensor_array<f32[?,1,8192,?]>
auto copy__ShapeOf = makePattern<ov::opset1::ShapeOf>({causal_mask_boolean_1}); // tensor_array<i32[4]>
auto copy__ShapeOf = makePattern<ov::opset1::ShapeOf>({causal_mask_boolean_slice |
causal_mask_boolean_strided_slice}); // tensor_array<i32[4]>
auto Constant_47319 = makeConst(ov::element::u8, ov::Shape({}), {0});
auto copy__Broadcast =
makePattern<ov::opset1::Broadcast>({masked_fill_Select, copy__ShapeOf, Constant_47319},
Expand All @@ -175,8 +181,10 @@ CausalMaskPreprocess::CausalMaskPreprocess() {
{{"special_zero", true}}); // tensor_array<f32[?,1,8192,8192]>
auto ScatterUpdate_93554 =
makePattern<ov::opset3::ScatterUpdate>({{0, 0, 0, 0}, {3}, kvLen, {0}}); // tensor_array<i32[4]>
auto slice_Slice_14 = GenSlice2(SliceAssign_201_Reshape_3, {0, 0, 0, 0}, ScatterUpdate_93554, {1, 1, 1, 1}, 3); // tensor_array<f32[?,1,8192,..8192]>
auto index_Gather = makePattern<ov::opset8::Gather>({slice_Slice_14, cache_positions, 2},
auto slice_StridedSlice_14 = GenStridedSlice(SliceAssign_201_Reshape_3, {0, 0, 0, 0},
ScatterUpdate_93554, {1, 1, 1, 1}, 3); // tensor_array<f32[?,1,8192,..8192]>
auto slice_Slice_14 = makePattern<ov::opset8::Slice>({SliceAssign_201_Reshape_3, {0}, kvLen, {1}, {3}});
auto index_Gather = makePattern<ov::opset8::Gather>({slice_Slice_14 | slice_StridedSlice_14, cache_positions, 2},
{{"batch_dims", 0}},
nullptr); // tensor_array<f32[?,1,?,..8192]>
auto result = index_Gather;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,7 @@ void Transformations::PostLpt() {
// Execute before snippets. Otherwise FQ will be converted to Subgraph
CPU_REGISTER_PASS_X64(postLPTPassManager, ConvertFqRnnToQuantizedRnn);

postLPTPassManager.register_pass<ov::pass::Serialize>("ser.xml", "ser.bin");
CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RoPEFusion);
CPU_REGISTER_PASS_ARM64(postLPTPassManager, ov::pass::RoPEFusion);
CPU_REGISTER_PASS_X64(postLPTPassManager, CausalMaskPreprocessFusion);
Expand Down

0 comments on commit dd69cb4

Please sign in to comment.