Skip to content

Commit

Permalink
[CPU]Fix GPT-J RoPE fusion (openvinotoolkit#23519)
Browse files Browse the repository at this point in the history
### Details:
 - *Support new RoPE pattern of GPT-J*
- *Local test shows 17 % improvement for 2nd token latency for BF16 in
`Intel(R) Xeon(R) Platinum 8468`*

### Tickets:
 - *CVS-134949*
  • Loading branch information
zhangYiIntel authored and bbielawx committed Apr 12, 2024
1 parent 63e1312 commit e301029
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,11 @@ ov::intel_cpu::RoPEFusionGPTJ::RoPEFusionGPTJ() {

auto varsplit = makePattern<opset1::VariadicSplit>({gather_sin_cos, -1, {ndims / 2, -1}});
varsplit->set_output_size(2);
auto unsqueeze_sin = makePattern<opset1::Reshape>({varsplit->output(0), {1, -1, 1, 32}});
auto unsqueeze_cos = makePattern<opset1::Reshape>({varsplit->output(1), {1, -1, 1, 32}});
// Reshape or UnSqueeze should both be support
auto unsqueeze_sin = makePattern<opset1::Reshape>({varsplit->output(0), {1, -1, 1, 32}}) |
makePattern<opset1::Unsqueeze>({varsplit->output(0), 2});
auto unsqueeze_cos = makePattern<opset1::Reshape>({varsplit->output(1), {1, -1, 1, 32}}) |
makePattern<opset1::Unsqueeze>({varsplit->output(1), 2});
// repeate cos/sin table
auto const_idx = makeConst(ov::element::i32, ov::PartialShape::dynamic(), [](const ov::op::v0::Constant& node) {
const auto& vec = node.get_vector<int32_t>();
Expand All @@ -402,9 +405,6 @@ ov::intel_cpu::RoPEFusionGPTJ::RoPEFusionGPTJ() {
auto repeat_interleave_sin = makePattern<opset8::Gather>({unsqueeze_sin, const_idx, 3}, {{"batch_dims", 0}});
auto repeat_interleave_cos = makePattern<opset8::Gather>({unsqueeze_cos, const_idx, 3}, {{"batch_dims", 0}});

auto t_cos = makePattern(ov::Rank(4));
auto t_sin = makePattern(ov::Rank(4));

// x interleave (-x[:,:,:, 1::2], x[:,:,:, 0::2])
auto slice_Slice_1174 = GenSlice(slice_Slice_965, 1, int32_max, 2, 3);

Expand All @@ -418,13 +418,16 @@ ov::intel_cpu::RoPEFusionGPTJ::RoPEFusionGPTJ() {
auto ShapeOf_169068 = makePattern<opset1::ShapeOf>({stack_1182});
auto flatten_Slice_1194 = GenSlice(ShapeOf_169068, 0, 3, 1, 0);
auto flatten_Concat_1197 = makePattern<opset1::Concat>({flatten_Slice_1194, {-1}}, {{"axis", 0}});
// If with special zero, no need to use shapeof to get full shape
auto flatten_Reshape_1198 = makePattern<opset1::Reshape>({stack_1182, flatten_Concat_1197});
auto flatten_Reshape_Zero =
makePattern<opset1::Reshape>({stack_1182, ov::pass::pattern::any_input()}, {{"special_zero", true}});

// x*cos [B,L,H,ndims]
auto mul_cos =
makePattern<opset1::Multiply>({slice_Slice_965, repeat_interleave_cos}, {{"auto_broadcast", "numpy"}});
auto mul_sin =
makePattern<opset1::Multiply>({flatten_Reshape_1198, repeat_interleave_sin}, {{"auto_broadcast", "numpy"}});
makePattern<opset1::Multiply>({flatten_Reshape_1198 | flatten_Reshape_Zero, repeat_interleave_sin}, {{"auto_broadcast", "numpy"}});

// *cos + *sin
auto rotary_emb = makePattern<opset1::Add>({mul_cos, mul_sin}, {{"auto_broadcast", "numpy"}});
Expand Down Expand Up @@ -460,22 +463,30 @@ ov::intel_cpu::RoPEFusionGPTJ::RoPEFusionGPTJ() {
auto new_node = std::make_shared<RoPENode>(new_args, config);
new_node->set_friendly_name(old_node->get_friendly_name());
ov::copy_runtime_info({pattern_map.at(varsplit).get_node_shared_ptr(),
pattern_map.at(unsqueeze_sin).get_node_shared_ptr(),
pattern_map.at(unsqueeze_cos).get_node_shared_ptr(),
pattern_map.at(repeat_interleave_sin).get_node_shared_ptr(),
pattern_map.at(repeat_interleave_cos).get_node_shared_ptr(),
pattern_map.at(neg_Multiply_1177).get_node_shared_ptr(),
pattern_map.at(Unsqueeze_65524).get_node_shared_ptr(),
pattern_map.at(Unsqueeze_65525).get_node_shared_ptr(),
pattern_map.at(stack_1182).get_node_shared_ptr(),
pattern_map.at(flatten_Concat_1197).get_node_shared_ptr(),
pattern_map.at(mul_cos).get_node_shared_ptr(),
pattern_map.at(mul_sin).get_node_shared_ptr(),
pattern_map.at(rotary_emb).get_node_shared_ptr(),
pattern_map.at(cat_Concat_1211).get_node_shared_ptr(),
pattern_map.at(permute_Transpose_1213).get_node_shared_ptr()},
new_node);
ov::replace_node(old_node, new_node);
// shapeof may be moved up from transpose to add,
// After RoPE fusion, shapeof must be moved to the data input of RoPE otherwise extra subgraph exists
std::shared_ptr<ov::Node> rotary_emb_node = pattern_map.at(rotary_emb).get_node_shared_ptr();
auto rotary_emb_out = rotary_emb_node->output(0);
if (rotary_emb_out.get_target_inputs().size() == 2) {
for (auto& input : rotary_emb_out.get_target_inputs()) {
if (ov::is_type<opset1::ShapeOf>(input.get_node())) {
input.replace_source_output(pattern_map.at(view_Reshape));
}
}
}
return true;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -457,5 +457,148 @@ TEST_F(RoPECPUTestQwen7b, smoke_CompareWithRefs) {
CheckNumberOfNodesWithType(compiledModel, "RoPE", 1);
}

class RoPECPUTestGPTJ : public SubgraphBaseTest, public testing::WithParamInterface<bool> {
public:
static std::string getTestCaseName(const testing::TestParamInfo<bool>& obj) {
bool hasShapeOf;
hasShapeOf = obj.param;
std::ostringstream result;
result << "hasShapeOf=" << hasShapeOf << std::endl;
return result.str();
}
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
const auto& funcInputs = function->inputs();

auto& input_shape = targetInputStaticShapes[0];
auto& sincos_shape = targetInputStaticShapes[1];
ov::Tensor t_input =
utils::create_and_fill_tensor(funcInputs[0].get_element_type(), input_shape, 2, -1.0f, 32768);
ov::Tensor t_cos_sin_cache =
utils::create_and_fill_tensor(funcInputs[1].get_element_type(), sincos_shape, 2, -1.0f, 32768);

inputs.clear();
inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input});
inputs.insert({funcInputs[1].get_node_shared_ptr(), t_cos_sin_cache});
}

protected:
std::shared_ptr<ov::Model> buildROPE_GPTJ(const int num_head,
const int hidden_dims,
const int rotary_dims,
bool hasShapeOf) {
auto int32_max = std::numeric_limits<std::int32_t>::max();
auto input =
std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{-1, -1, num_head, hidden_dims});
auto sincos = std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{-1, -1, rotary_dims});

auto slice_Slice_965 =
makeOP<ov::op::v1::StridedSlice>({input, {0, 0, 0, 0}, {0, 0, 0, rotary_dims}, {1, 1, 1, 1}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
slice_Slice_965->set_friendly_name("slice_Slice_965");

auto varsplit = makeOP<ov::op::v1::VariadicSplit>({sincos, -1, {rotary_dims / 2, -1}});
varsplit->set_output_size(2);
varsplit->set_friendly_name("varsplit");
auto unsqueeze_sin = makeOP<opset1::Unsqueeze>({varsplit->output(0), 2});
auto unsqueeze_cos = makeOP<opset1::Unsqueeze>({varsplit->output(1), 2});
std::vector<int32_t> gather_idx(rotary_dims, 1);
int32_t v = 0;
for (size_t i = 0; i < gather_idx.size(); i += 2, v++) {
gather_idx[i] = v;
gather_idx[i + 1] = v;
}

auto const_idx = makeConst(ov::element::i32, ov::Shape({static_cast<size_t>(rotary_dims)}), gather_idx);
auto constant_155588 = makeConst(element::f32,
ov::Shape({
1,
1,
1,
1,
}),
{-1.000000f});
auto repeat_interleave_sin = makeOP<opset8::Gather>({unsqueeze_sin, const_idx, 3}, {{"batch_dims", 0}});
auto repeat_interleave_cos = makeOP<opset8::Gather>({unsqueeze_cos, const_idx, 3}, {{"batch_dims", 0}});
repeat_interleave_sin->set_friendly_name("repeat_interleave_sin");
repeat_interleave_cos->set_friendly_name("repeat_interleave_cos");
// x interleave (-x[:,:,:, 1::2], x[:,:,:, 0::2])
auto slice_Slice_1174 =
makeOP<ov::op::v1::StridedSlice>({slice_Slice_965, {0, 0, 0, 1}, {0, 0, 0, int32_max}, {1, 1, 1, 2}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto neg_Multiply_1177 =
makeOP<opset1::Multiply>({slice_Slice_1174, constant_155588}, {{"auto_broadcast", "numpy"}});
auto Unsqueeze_65524 = makeOP<opset1::Unsqueeze>({neg_Multiply_1177, -1});

auto slice_Slice_1168 =
makeOP<ov::op::v1::StridedSlice>({slice_Slice_965, {0, 0, 0, 0}, {0, 0, 0, int32_max}, {1, 1, 1, 2}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto Unsqueeze_65525 = makeOP<opset1::Unsqueeze>({slice_Slice_1168, -1});
auto stack_1182 = makeOP<opset1::Concat>({Unsqueeze_65524, Unsqueeze_65525}, {{"axis", -1}});
auto flatten_Reshape_1198 =
makeOP<opset1::Reshape>({stack_1182, {0, 0, num_head, rotary_dims}}, {{"special_zero", true}});
// x*cos [B,L,H,ndims]
auto mul_cos =
makeOP<opset1::Multiply>({slice_Slice_965, repeat_interleave_cos}, {{"auto_broadcast", "numpy"}});
mul_cos->set_friendly_name("mul_cos");
auto mul_sin =
makeOP<opset1::Multiply>({flatten_Reshape_1198, repeat_interleave_sin}, {{"auto_broadcast", "numpy"}});
// *cos + *sin
auto rotary_emb = makeOP<opset1::Add>({mul_cos, mul_sin}, {{"auto_broadcast", "numpy"}});

auto slice_Slice_971 =
makeOP<ov::op::v1::StridedSlice>({input, {0, 0, 0, rotary_dims}, {0, 0, 0, int32_max}, {1, 1, 1, 1}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto cat_Concat_1211 = makeOP<opset1::Concat>({rotary_emb, slice_Slice_971}, {{"axis", -1}});
auto permute_Transpose_1213 = makeOP<opset1::Transpose>({cat_Concat_1211, {0, 2, 1, 3}});
ov::NodeVector model_output = {permute_Transpose_1213};
if (hasShapeOf) {
auto shapeOf = makeOP<opset1::ShapeOf>({rotary_emb}, {{"output_type", "i32"}});
auto gather = makeOP<opset8::Gather>({shapeOf, {1}, 0}, {{"batch_dims", 0}});
model_output.push_back(gather);
}
return std::make_shared<ov::Model>(model_output, ov::ParameterVector{input, sincos});
}
void SetUp() override {
targetDevice = ov::test::utils::DEVICE_CPU;
bool hasShapeOf = this->GetParam();
const int batch = 2;
const int seq_length = 7;
const int num_head = 16;
const int hidden_dims = 256;
const int rotary_dims = 64;

InputShape input = {{batch, seq_length, num_head, hidden_dims}, {{batch, seq_length, num_head, hidden_dims}}};
InputShape sincos = {{batch, seq_length, rotary_dims}, {{batch, seq_length, rotary_dims}}};
init_input_shapes({input, sincos});
function = buildROPE_GPTJ(num_head, hidden_dims, rotary_dims, hasShapeOf);
}
};

TEST_P(RoPECPUTestGPTJ, smoke_CompareWithRefs) {
run();
CheckNumberOfNodesWithType(compiledModel, "RoPE", 1);
}

INSTANTIATE_TEST_SUITE_P(smoke_RoPECPUTestGPTJ,
RoPECPUTestGPTJ,
::testing::Values(true, false),
RoPECPUTestGPTJ::getTestCaseName);

} // namespace test
} // namespace ov

0 comments on commit e301029

Please sign in to comment.